Compare commits

..

1 Commits

Author SHA1 Message Date
05ce7476ae ggml-ci: update input env variables to GG_BUILD_ 2025-03-14 03:14:44 -05:00
361 changed files with 40394 additions and 60877 deletions

View File

@ -13,6 +13,8 @@ WORKDIR /app
ARG CUDA_DOCKER_ARCH=all
# Set nvcc architecture
ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
# Enable cuBLAS
ENV GGML_CUDA=1
RUN apt-get update && \
apt-get install -y build-essential libsdl2-dev wget cmake git \
@ -23,8 +25,7 @@ ENV CUDA_MAIN_VERSION=12.3
ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH
COPY .. .
# Enable cuBLAS
RUN make base.en CMAKE_ARGS="-DGGML_CUDA=1"
RUN make base.en
FROM ${BASE_CUDA_RUN_CONTAINER} AS runtime
ENV CUDA_MAIN_VERSION=12.3
@ -36,5 +37,4 @@ RUN apt-get update && \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
COPY --from=build /app /app
ENV PATH=/app/build/bin:$PATH
ENTRYPOINT [ "bash", "-c" ]

View File

@ -1,29 +0,0 @@
ARG UBUNTU_VERSION=22.04
# This needs to generally match the container host's environment.
ARG MUSA_VERSION=rc3.1.1
# Target the MUSA build image
ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
# Target the MUSA runtime image
ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}
FROM ${BASE_MUSA_DEV_CONTAINER} AS build
WORKDIR /app
RUN apt-get update && \
apt-get install -y build-essential libsdl2-dev wget cmake git \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
COPY .. .
# Enable muBLAS
RUN make base.en CMAKE_ARGS="-DGGML_MUSA=1"
FROM ${BASE_MUSA_RUN_CONTAINER} AS runtime
WORKDIR /app
RUN apt-get update && \
apt-get install -y curl ffmpeg wget cmake git \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
COPY --from=build /app /app
ENV PATH=/app/build/bin:$PATH
ENTRYPOINT [ "bash", "-c" ]

View File

@ -16,5 +16,4 @@ RUN apt-get update && \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
COPY --from=build /app /app
ENV PATH=/app/build/bin:$PATH
ENTRYPOINT [ "bash", "-c" ]

View File

@ -1,3 +0,0 @@
build*/
.github/
.devops/

View File

@ -1,11 +1,55 @@
name: Bindings Tests (Ruby)
on:
push:
branches:
- master
paths:
- bindings/ruby/**
- src/**/*.c
- src/**/*.cpp
- src/**/*.h
- src/**/*.m
- src/**/*.metal
- include/**/*.c
- include/**/*.cpp
- include/**/*.h
- include/**/*.m
- include/**/*.metal
- ggml/**/*.c
- ggml/**/*.cpp
- ggml/**/*.h
- ggml/**/*.m
- ggml/**/*.metal
- scripts/get-flags.mk
- examples/common.h
- examples/common.cpp
- examples/common-whisper.h
- examples/common-whisper.cpp
- examples/stb_vorbis.c
- examples/miniaudio.h
pull_request:
types: [opened, synchronize, reopened]
paths:
- bindings/ruby/**
- src/**/*.c
- src/**/*.cpp
- src/**/*.h
- src/**/*.m
- src/**/*.metal
- include/**/*.c
- include/**/*.cpp
- include/**/*.h
- include/**/*.m
- include/**/*.metal
- ggml/**/*.c
- ggml/**/*.cpp
- ggml/**/*.h
- ggml/**/*.m
- ggml/**/*.metal
- scripts/get-flags.mk
- examples/common.h
- examples/common.cpp
- examples/common-whisper.h
- examples/common-whisper.cpp
- examples/stb_vorbis.c
- examples/miniaudio.h
jobs:
ubuntu-22:
@ -16,6 +60,6 @@ jobs:
steps:
- uses: ruby/setup-ruby@v1
with:
ruby-version: '3.2'
ruby-version: '3.1'
- uses: actions/checkout@v4
- run: rake test

View File

@ -6,81 +6,17 @@ on:
- master
pull_request:
types: [opened, synchronize, reopened]
workflow_dispatch:
inputs:
create_release:
description: 'Create new release'
required: true
type: boolean
pre_release_tag:
description: 'Pre-release tag name'
required: false
type: string
run_type:
description: 'Workflow type to run'
required: true
type: choice
options:
- full-ci
- release-only
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
cancel-in-progress: true
permissions:
contents: write # for creating release
env:
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
ubuntu_image: "ubuntu:22.04"
VCPKG_BINARY_SOURCES: "clear;x-gha,readwrite"
jobs:
determine-tag:
runs-on: ubuntu-latest
outputs:
tag_name: ${{ steps.tag.outputs.name }}
steps:
- name: Checkout with full history
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Determine tag name
id: tag
shell: bash
run: |
BUILD_NUMBER=$(git rev-list --count HEAD)
SHORT_HASH=$(git rev-parse --short=7 HEAD)
CUSTOM_TAG="${{ github.event.inputs.pre_release_tag }}"
echo "Raw values:"
echo "BUILD_NUMBER: $BUILD_NUMBER"
echo "SHORT_HASH: $SHORT_HASH"
echo "BRANCH_NAME: ${{ env.BRANCH_NAME }}"
echo "CUSTOM_TAG: $CUSTOM_TAG"
# Use custom tag if provided
if [[ -n "$CUSTOM_TAG" ]]; then
echo "Using custom tag"
TAG_NAME="${CUSTOM_TAG}"
elif [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then
echo "Using master branch format"
TAG_NAME="b${BUILD_NUMBER}"
else
echo "Using non-master branch format"
SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-')
TAG_NAME="${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}"
fi
echo "Final tag name: $TAG_NAME"
echo "name=$TAG_NAME" >> $GITHUB_OUTPUT
ubuntu-22:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: ubuntu-22.04
strategy:
@ -107,8 +43,6 @@ jobs:
cmake --build build --config Release -j $(nproc)'
ubuntu-22-arm64:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: ubuntu-22.04
strategy:
@ -135,8 +69,6 @@ jobs:
cmake --build build --config Release -j $(nproc)'
ubuntu-22-arm-v7:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: ubuntu-22.04
strategy:
@ -163,8 +95,6 @@ jobs:
cmake --build build --config Release -j $(nproc)'
macOS-latest:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: macOS-latest
strategy:
@ -199,28 +129,31 @@ jobs:
-DCMAKE_OSX_ARCHITECTURES="arm64;x86_64"
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)
- name: xcodebuild for swift package
id: xcodebuild
run: |
./build-xcframework.sh
# freeBSD-latest:
# runs-on: macos-13
# runs-on: macos-12
#
# steps:
# - name: Clone
# uses: actions/checkout@v4
#
# - name: Build
# uses: cross-platform-actions/action@v0.27.0
# uses: cross-platform-actions/action@v0.24.0
# with:
# operating_system: freebsd
# version: '14.2'
# version: '13.3'
# run: |
# sudo pkg update
# sudo pkg install -y gmake sdl2 cmake git
# sudo pkg install -y gmake sdl2 cmake
# cmake -B build
# cmake --build build --config Release
ubuntu-22-gcc:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: ubuntu-22.04
strategy:
@ -249,8 +182,6 @@ jobs:
ctest -L gh --output-on-failure'
ubuntu-22-gcc-arm64:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: ubuntu-22.04
strategy:
@ -279,8 +210,6 @@ jobs:
ctest -L gh --output-on-failure'
ubuntu-22-gcc-arm-v7:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: ubuntu-22.04
strategy:
@ -309,8 +238,6 @@ jobs:
ctest -L gh --output-on-failure'
ubuntu-22-clang:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: ubuntu-22.04
strategy:
@ -342,8 +269,6 @@ jobs:
ctest -L gh --output-on-failure'
ubuntu-22-gcc-sanitized:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: ubuntu-22.04
strategy:
@ -367,15 +292,11 @@ jobs:
set -e
apt update
apt install -y build-essential cmake git
cmake . -DCMAKE_BUILD_TYPE=Debug \
-DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON \
-DGGML_OPENMP=OFF
cmake . -DCMAKE_BUILD_TYPE=Debug -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON
make
ctest -L gh --output-on-failure'
ubuntu-22-cmake-sycl:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: ubuntu-22.04
strategy:
@ -426,8 +347,6 @@ jobs:
cmake --build . --config Release -j $(nproc)
ubuntu-22-cmake-sycl-fp16:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: ubuntu-22.04
strategy:
@ -478,8 +397,6 @@ jobs:
cmake --build . --config Release -j $(nproc)
windows-msys2:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: windows-latest
strategy:
@ -524,8 +441,6 @@ jobs:
cmake --build build --config ${{ matrix.build }} -j $(nproc)
windows:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: windows-latest
strategy:
@ -561,7 +476,6 @@ jobs:
run: >
cmake -S . -B ./build -A ${{ matrix.arch }}
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
-DBUILD_SHARED_LIBS=ON
-DWHISPER_SDL2=${{ matrix.sdl2 }}
- name: Build
@ -573,37 +487,12 @@ jobs:
if: matrix.sdl2 == 'ON'
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
- name: Upload SDL2.dll
if: matrix.sdl2 == 'ON'
- name: Upload dll
uses: actions/upload-artifact@v4
with:
name: ${{ matrix.s2arc }}_SDL2.dll
path: build/bin/${{ matrix.build }}/SDL2.dll
- name: Upload whisper dll
uses: actions/upload-artifact@v4
with:
name: whisper_${{ matrix.arch }}.dll
name: ${{ matrix.jnaPath }}_whisper.dll
path: build/bin/${{ matrix.build }}/whisper.dll
- name: Upload ggml dll
uses: actions/upload-artifact@v4
with:
name: ggml_${{ matrix.arch }}.dll
path: build/bin/${{ matrix.build }}/ggml.dll
- name: Upload ggml base dll
uses: actions/upload-artifact@v4
with:
name: ggml_base_${{ matrix.arch }}.dll
path: build/bin/${{ matrix.build }}/ggml-base.dll
- name: Upload ggml cpu dll
uses: actions/upload-artifact@v4
with:
name: ggml_cpu_${{ matrix.arch }}.dll
path: build/bin/${{ matrix.build }}/ggml-cpu.dll
- name: Upload binaries
if: matrix.sdl2 == 'ON'
uses: actions/upload-artifact@v4
@ -612,8 +501,6 @@ jobs:
path: build/bin/${{ matrix.build }}
windows-blas:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: windows-latest
strategy:
@ -687,8 +574,6 @@ jobs:
path: build/bin/${{ matrix.build }}
windows-cublas:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: windows-2019
strategy:
matrix:
@ -705,134 +590,15 @@ jobs:
- name: Clone repository
uses: actions/checkout@v4
- name: Install Ninja
id: install_ninja
run: |
choco install ninja
- name: Install ccache
uses: hendrikmuhs/ccache-action@v1.2.16
with:
key: ${{ github.job }}-${{ matrix.cuda-toolkit }}-${{ matrix.build }}
variant: sccache
evict-old-files: 5d
- name: Install Cuda Toolkit 11.8.0
if: ${{ matrix.cuda-toolkit == '11.8.0' }}
run: |
$CUDA_VERSION = ${{ matrix.cuda-toolkit }}
$CUDA_TOOLKIT_DIR = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$CUDA_VERSION"
$CUDA_DOWNLOAD = "https://developer.download.nvidia.com/compute/cuda/redist"
# Components versions
$CUDART_VER = "11.8.89"
$NVCC_VER = "11.8.89"
$NVRTC_VER = "11.8.89"
$CUBLAS_VER = "11.8.1.74"
$NVTX_VER = "11.8.86"
$VS_VER = "11.8.86"
$NVPROF_VER = "11.8.87"
$CCCL_VER = "11.8.89"
# Create the directory where the CUDA Toolkit will be installed
mkdir -p $CUDA_TOOLKIT_DIR
# Install unzip to extract the downloaded files
choco install unzip -y
# Download all the required components
curl -O "$CUDA_DOWNLOAD/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-${CUDART_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-${NVCC_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/libcublas/windows-x86_64/libcublas-windows-x86_64-${CUBLAS_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-${NVTX_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-${VS_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-${CCCL_VER}-archive.zip"
# Extract all the downloaded files to the CUDA Toolkit directory
unzip '*.zip' -d $CUDA_TOOLKIT_DIR
# Copy all the extracted files to the main CUDA Toolkit directory
xcopy "$CUDA_TOOLKIT_DIR\cuda_cudart-windows-x86_64-${CUDART_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\cuda_nvcc-windows-x86_64-${NVCC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\libcublas-windows-x86_64-${CUBLAS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\cuda_nvtx-windows-x86_64-${NVTX_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\cuda_cccl-windows-x86_64-${CCCL_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
# Visual Studio integration
xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\MSBuild\Microsoft\VC\v160\BuildCustomizations" /E /I /H /Y
# Set environment variables
echo "$CUDA_TOOLKIT_DIR\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "$CUDA_TOOLKIT_DIR\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "CUDA_PATH=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8
echo "CUDA_PATH_V11_8=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8
- name: Install Cuda Toolkit 12.2.0
if: ${{ matrix.cuda-toolkit == '12.2.0' }}
run: |
$CUDA_VERSION = ${{ matrix.cuda-toolkit }}
$CUDA_TOOLKIT_DIR = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$CUDA_VERSION"
$CUDA_DOWNLOAD = "https://developer.download.nvidia.com/compute/cuda/redist"
# Components versions
$CUDART_VER = "12.2.140"
$NVCC_VER = "12.2.140"
$NVRTC_VER = "12.2.140"
$CUBLAS_VER = "12.2.5.6"
$NVTX_VER = "12.2.140"
$PROFILER_VER = "12.2.140"
$VS_VER = "12.2.140"
$NVPROF_VER = "12.2.142"
$CCCL_VER = "12.2.140"
# Create the directory where the CUDA Toolkit will be installed
mkdir -p $CUDA_TOOLKIT_DIR
# Install unzip to extract the downloaded files
choco install unzip -y
# Download all the required components
curl -O "$CUDA_DOWNLOAD/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-${CUDART_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-${NVCC_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/libcublas/windows-x86_64/libcublas-windows-x86_64-${CUBLAS_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-${NVTX_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/cuda_profiler_api/windows-x86_64/cuda_profiler_api-windows-x86_64-${PROFILER_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-${VS_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-${CCCL_VER}-archive.zip"
# Extract all the downloaded files to the CUDA Toolkit directory
unzip -q '*.zip' -d $CUDA_TOOLKIT_DIR
# Copy all the extracted files to the main CUDA Toolkit directory
xcopy "$CUDA_TOOLKIT_DIR\cuda_cudart-windows-x86_64-${CUDART_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\cuda_nvcc-windows-x86_64-${NVCC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\libcublas-windows-x86_64-${CUBLAS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\cuda_nvtx-windows-x86_64-${NVTX_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\cuda_cccl-windows-x86_64-${CCCL_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\cuda_profiler_api-windows-x86_64-${PROFILER_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
# Visual Studio integration
xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\MSBuild\Microsoft\VC\v160\BuildCustomizations" /E /I /H /Y
# Set environment variables
echo "$CUDA_TOOLKIT_DIR\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "$CUDA_TOOLKIT_DIR\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "CUDA_PATH=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8
echo "CUDA_PATH_V12_2=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8
- name: Add msbuild to PATH
uses: microsoft/setup-msbuild@v2
- name: Install CUDA Toolkit
id: cuda-toolkit
uses: Jimver/cuda-toolkit@v0.2.15
with:
cuda: '${{ matrix.cuda-toolkit }}'
- name: Install 7-Zip
run: choco install 7zip -y
@ -844,30 +610,25 @@ jobs:
echo "SDL2_DIR=${{ github.workspace }}\SDL2-${{ matrix.sdl2_ver }}\cmake" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "${{ github.workspace }}\SDL2-${{ matrix.sdl2_ver }}\cmake" > SDL2_PATH.txt
- name: Install cmake
run: choco install cmake
- name: Configure CMake
shell: cmd
run: |
cmake -S . -B ./build -A ${{ matrix.arch }} ^
-DCMAKE_BUILD_TYPE=${{ matrix.build }} ^
-DGGML_CUDA=${{ matrix.cublas }} ^
-DCMAKE_CUDA_ARCHITECTURES=all ^
-DWHISPER_SDL2=${{ matrix.sdl2 }} ^
-DSDL2_DIR="%SDL2_DIR%"
- name: Build Project
shell: cmd
run: |
call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
cmake --version
where cmake
cmake -S . -B build -G "Ninja Multi-Config" ^
-DCMAKE_BUILD_TYPE=${{ matrix.build }} ^
-DGGML_CUDA=${{ matrix.cublas }} ^
-DWHISPER_SDL2=${{ matrix.sdl2 }} ^
-DSDL2_DIR="%SDL2_DIR%"
set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1
cmake --build build --config ${{ matrix.build }} -j %NUMBER_OF_PROCESSORS%
- name: Check sccache status after build
run: |
sccache --show-stats
cd ./build
cmake --build . --config ${{ matrix.build }}
- name: Copy CUDA DLLs
run: |
Get-ChildItem "$env:CUDA_PATH\bin\" -Filter "*.dll" |
Get-ChildItem "${{ steps.cuda-toolkit.outputs.CUDA_PATH }}/bin/" -Filter "*.dll" |
Copy-Item -Destination "build/bin/${{ matrix.build }}"
- name: Copy SDL2.dll
@ -881,8 +642,6 @@ jobs:
path: build/bin/${{ matrix.build }}
emscripten:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: ubuntu-22.04
strategy:
@ -906,7 +665,6 @@ jobs:
ios-xcode-build:
runs-on: macos-latest
needs: determine-tag
strategy:
matrix:
@ -949,26 +707,7 @@ jobs:
- name: Build swiftui example
run: xcodebuild -project examples/whisper.swiftui/whisper.swiftui.xcodeproj -scheme WhisperCppDemo -configuration ${{ matrix.build }} -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' FRAMEWORK_FOLDER_PATH=./build-ios build
- name: Pack artifacts
id: pack_artifacts
if: ${{ (github.event_name == 'push' && github.ref == 'refs/heads/master') ||
github.event.inputs.create_release == 'true' ||
github.event.inputs.pre_release_tag != '' }}
run: |
zip --symlinks -r whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip build-apple/whisper.xcframework
- name: Upload artifacts
if: ${{ (github.event_name == 'push' && github.ref == 'refs/heads/master') ||
github.event.inputs.create_release == 'true' ||
github.event.inputs.pre_release_tag != '' }}
uses: actions/upload-artifact@v4
with:
path: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip
name: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip
android:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: ubuntu-22.04
steps:
@ -997,113 +736,64 @@ jobs:
cd whisper/examples/whisper.android
./gradlew assembleRelease --no-daemon
android_java:
runs-on: ubuntu-22.04
steps:
- name: Clone
uses: actions/checkout@v4
- name: set up JDK 11
uses: actions/setup-java@v4
with:
java-version: '11'
distribution: 'temurin'
cache: gradle
- name: Setup Android SDK
uses: android-actions/setup-android@v3
with:
cmdline-tools-version: 9.0
- name: Build
run: |
cd examples/whisper.android.java
chmod +x ./gradlew
./gradlew assembleRelease
bindings-java:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
needs: ['windows']
runs-on: windows-latest
steps:
- uses: actions/checkout@v4
- name: Install Java
uses: actions/setup-java@v4
with:
distribution: zulu
java-version: 20
- name: Download Whisper Windows lib
uses: actions/download-artifact@v4
with:
name: whisper_x64.dll
- name: Download GGML Windows lib
uses: actions/download-artifact@v4
with:
name: ggml_x64.dll
- name: Download GGML Base Windows lib
uses: actions/download-artifact@v4
with:
name: ggml_base_x64.dll
- name: Download GGML CPU Windows lib
uses: actions/download-artifact@v4
with:
name: ggml_cpu_x64.dll
- name: Download SDL2.dll
uses: actions/download-artifact@v4
with:
name: x64_SDL2.dll
- name: List downloaded files
shell: pwsh
run: |
Get-ChildItem -Path "." -Recurse -Filter "*.dll"
- name: Move DLL to correct location
shell: pwsh
run: |
New-Item -Path "build\bin\Release" -ItemType Directory -Force
Copy-Item -Path "whisper.dll" -Destination "build\bin\Release\whisper.dll" -Force
Write-Host "Copied whisper.dll to build\bin\Release\whisper.dll directory"
Copy-Item -Path "ggml.dll" -Destination "build\bin\Release\ggml.dll" -Force
Write-Host "Copied ggml.dll to build\bin\Release\ggml.dll directory"
Copy-Item -Path "ggml-base.dll" -Destination "build\bin\Release\ggml-base.dll" -Force
Write-Host "Copied ggml-base.dll to build\bin\Release\ggml-base.dll directory"
Copy-Item -Path "ggml-cpu.dll" -Destination "build\bin\Release\ggml-cpu.dll" -Force
Write-Host "Copied ggml-cpu.dll to build\bin\Release\ggml-cpu.dll directory"
Copy-Item -Path "SDL2.dll" -Destination "build\bin\Release\SDL2.dll" -Force
Write-Host "Copied SDL2.dll to build\bin\Release\SDL2.dll directory"
- name: List build release files
shell: pwsh
run: |
Get-ChildItem -Path "build\Release" -Recurse -Filter "*.dll"
- name: Build
run: |
models\download-ggml-model.cmd tiny.en models/
cd bindings/java
chmod +x ./gradlew
./gradlew build --info
- name: Upload jar
uses: actions/upload-artifact@v4
with:
name: whispercpp.jar
path: bindings/java/build/libs/whispercpp-*.jar
# TODO: disable because of following fail: https://github.com/ggerganov/whisper.cpp/actions/runs/11019444420/job/30627193602
# android_java:
# runs-on: ubuntu-22.04
#
# steps:
# - name: Clone
# uses: actions/checkout@v4
#
# - name: set up JDK 11
# uses: actions/setup-java@v4
# with:
# java-version: '11'
# distribution: 'temurin'
# cache: gradle
#
# - name: Setup Android SDK
# uses: android-actions/setup-android@v3
# with:
# cmdline-tools-version: 9.0
#
# - name: Build
# run: |
# cd examples/whisper.android.java
# chmod +x ./gradlew
# ./gradlew assembleRelease
# TODO: disabled because of following fail: https://github.com/ggerganov/whisper.cpp/actions/runs/9686220096/job/26735899598
# java:
# needs: [ 'windows' ]
# runs-on: windows-latest
# steps:
# - uses: actions/checkout@v4
#
# - name: Install Java
# uses: actions/setup-java@v4
# with:
# distribution: zulu
# java-version: 20
#
# - name: Download Windows lib
# uses: actions/download-artifact@v4
# with:
# name: win32-x86-64_whisper.dll
# path: bindings/java/build/generated/resources/main/win32-x86-64
#
# - name: Build
# run: |
# models\download-ggml-model.cmd tiny.en
# cd bindings/java
# chmod +x ./gradlew
# ./gradlew build
#
# - name: Upload jar
# uses: actions/upload-artifact@v4
# with:
# name: whispercpp.jar
# path: bindings/java/build/libs/whispercpp-*.jar
#
# - name: Publish package
# if: ${{ github.ref == 'refs/heads/master' }}
# uses: gradle/gradle-build-action@v2.4.2
@ -1117,8 +807,6 @@ jobs:
# PGP_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }}
quantize:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: ubuntu-22.04
steps:
@ -1131,95 +819,3 @@ jobs:
cmake -B build
cmake --build build --config Release
./build/bin/quantize models/ggml-tiny.en.bin models/ggml-tiny.en-q4_0.bin q4_0
release:
if: ${{ github.event.inputs.create_release == 'true' || github.event.inputs.pre_release_tag != '' }}
runs-on: ubuntu-latest
needs:
- determine-tag
- ios-xcode-build
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: ccache
uses: hendrikmuhs/ccache-action@v1.2.16
with:
key: release
evict-old-files: 1d
# Downloads all the artifacts from the previous jobs
- name: Download artifacts
id: download-artifact
uses: actions/download-artifact@v4
with:
path: ./artifact
- name: Move artifacts
id: move_artifacts
run: mkdir -p ./artifact/release && mv ./artifact/*/*.zip ./artifact/release
- name: Create release
id: create_release
uses: ggml-org/action-create-release@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
tag_name: ${{ needs.determine-tag.outputs.tag_name }}
prerelease: ${{ github.event.inputs.pre_release_tag != '' }}
- name: Upload release
id: upload_release
uses: actions/github-script@v3
with:
github-token: ${{secrets.GITHUB_TOKEN}}
script: |
const path = require('path');
const fs = require('fs');
const release_id = '${{ steps.create_release.outputs.id }}';
for (let file of await fs.readdirSync('./artifact/release')) {
if (path.extname(file) === '.zip') {
console.log('uploadReleaseAsset', file);
await github.repos.uploadReleaseAsset({
owner: context.repo.owner,
repo: context.repo.repo,
release_id: release_id,
name: file,
data: await fs.readFileSync(`./artifact/release/${file}`)
});
}
}
coreml-base-en:
if: ${{ (github.event_name == 'push' && github.ref == 'refs/heads/master') ||
github.event.inputs.create_release == 'true' ||
github.event.inputs.pre_release_tag != '' }}
runs-on: macos-latest
needs: determine-tag
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set environment variables
id: set_vars
run: |
echo "MODEL_NAME=base.en" >> $GITHUB_ENV
echo "GEN_MODEL_NAME=whisper-${{ needs.determine-tag.outputs.tag_name }}-ggml-base.en-encoder.mlmodelc" >> $GITHUB_ENV
- name: Download model
run: |
./models/download-ggml-model.sh ${{ env.MODEL_NAME }}
- name: Generate CoreML model
run: |
python3.11 -m venv venv
source venv/bin/activate
pip install ane_transformers openai-whisper coremltools
./models/generate-coreml-model.sh ${{ env.MODEL_NAME }}

View File

@ -18,7 +18,6 @@ jobs:
matrix:
config:
- { tag: "main", dockerfile: ".devops/main.Dockerfile", platform: "linux/amd64" }
- { tag: "main-musa", dockerfile: ".devops/main-musa.Dockerfile", platform: "linux/amd64" }
#TODO: the cuda image keeps failing - disable for now
# https://github.com/ggerganov/whisper.cpp/actions/runs/11019444428/job/30602020339
#- { tag: "main-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platform: "linux/amd64" }

View File

@ -1,91 +0,0 @@
name: Examples WASM
on:
push:
branches: ["master"]
workflow_dispatch:
permissions:
contents: read
pages: write
id-token: write
concurrency:
group: "pages"
cancel-in-progress: false
jobs:
deploy-wasm-github-pages:
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Setup Pages
uses: actions/configure-pages@v4
- name: Setup emsdk
uses: mymindstorm/setup-emsdk@v14
- name: Build WASM Examples
# Enable for real build later in whisper.cpp
run: |
mkdir -p build-em && cd build-em
emcmake cmake .. -DCMAKE_BUILD_TYPE=Release
make -j
- name: Create staging directory
run: mkdir -p staging
- name: Create .nojekyll file in staging directory
run: touch staging/.nojekyll
- name: Copy application files
run: |
build_dir=build-em/bin
ls ${build_dir}
# command.wasm
target_dir=staging/command.wasm
mkdir -p ${target_dir}
cp ${build_dir}/command.wasm/{index.html,command.js,helpers.js} ${target_dir}
cp ${build_dir}/libcommand.js ${target_dir}
# bench.wasm
target_dir=staging/bench.wasm
mkdir -p ${target_dir}
cp ${build_dir}/bench.wasm/{index.html,bench.js,helpers.js} ${target_dir}
cp ${build_dir}/libbench.js ${target_dir}
# stream.wasm
target_dir=staging/stream.wasm
mkdir -p ${target_dir}
cp ${build_dir}/stream.wasm/{index.html,stream.js,helpers.js} ${target_dir}
cp ${build_dir}/libstream.js ${target_dir}
# whisper.wasm (this will be the main example page)
target_dir=staging
mkdir -p ${target_dir}
cp ${build_dir}/whisper.wasm/{index.html,main.js,helpers.js} ${target_dir}
cp ${build_dir}/libmain.js ${target_dir}
# Copy Cross-Origin Isolation service worker
cp -v examples/coi-serviceworker.js staging/
- name: List files in staging directory (for debugging)
run: |
echo "Files in staging directory:"
find staging -type f | sort
- name: Upload artifact
uses: actions/upload-pages-artifact@v3
with:
path: ./staging
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4

0
.gitmodules vendored Normal file
View File

View File

@ -1,6 +1,6 @@
cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories.
project("whisper.cpp" C CXX)
project("whisper.cpp" VERSION 1.7.5)
project("whisper.cpp" VERSION 1.7.4)
include(CheckIncludeFileCXX)
set(SOVERSION 1)
@ -38,13 +38,8 @@ if (EMSCRIPTEN)
# TODO: without these, we get the following error:
# wasm-ld: error: --shared-memory is disallowed by whisper.cpp.o because it was not compiled with 'atomics' or 'bulk-memory' features.
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -s TOTAL_STACK=5242880")
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -s TOTAL_STACK=5242880")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-deprecated")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread -s TOTAL_STACK=5242880")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread -s TOTAL_STACK=5242880")
else()
if (MINGW)
set(BUILD_SHARED_LIBS_DEFAULT OFF)
@ -67,8 +62,7 @@ option(WHISPER_ALL_WARNINGS "whisper: enable all compiler warnings"
option(WHISPER_ALL_WARNINGS_3RD_PARTY "whisper: enable all compiler warnings in 3rd party libs" OFF)
# build
option(WHISPER_FATAL_WARNINGS "whisper: enable -Werror flag" OFF)
option(WHISPER_USE_SYSTEM_GGML "whisper: use system-installed GGML library" OFF)
option(WHISPER_FATAL_WARNINGS "whisper: enable -Werror flag" OFF)
# sanitizers
option(WHISPER_SANITIZE_THREAD "whisper: enable thread sanitizer" OFF)
@ -127,31 +121,7 @@ whisper_option_depr(WARNING WHISPER_SYCL_F16 GGML_SYCL_F16)
#
if (NOT TARGET ggml)
if (WHISPER_USE_SYSTEM_GGML)
find_package(ggml REQUIRED)
if (NOT ggml_FOUND)
message(FATAL_ERROR "System-installed GGML library not found.")
endif()
add_library(ggml ALIAS ggml::ggml)
else()
add_subdirectory(ggml)
if(WIN32)
# The following adds a _DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR macro and is a workaround for
# the Windows C++ standard library which does not support constexpr mutexes.
# From the release notes://github.com/microsoft/STL/wiki/Changelog
# Disable constexpr mutex constructor on Windows
# Fixed mutex's constructor to be constexpr. #3824 #4000 #4339
# Note: Programs that aren't following the documented restrictions on binary compatibility may encounter
# null dereferences in mutex machinery. You must follow this rule:
# When you mix binaries built by different supported versions of the toolset, the Redistributable version
# must be at least as new as the latest toolset used by any app component.
# You can define _DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR as an escape hatch.
#
# Specifically to whisper.cpp this would cause a crash when using the Java bindings.
# resulting in a Invalid memory access error.
target_compile_definitions(ggml-base PRIVATE _DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR)
endif()
endif()
add_subdirectory(ggml)
# ... otherwise assume ggml is added by a parent CMakeLists.txt
endif()
add_subdirectory(src)
@ -206,43 +176,10 @@ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/whisper.pc"
#
if (WHISPER_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
include(CTest)
add_subdirectory(tests)
#include(CTest)
#add_subdirectory(tests)
endif ()
if (WHISPER_BUILD_EXAMPLES)
add_subdirectory(examples)
endif()
if (MSVC)
set(MSVC_WARNING_FLAGS
/wd4101 # Unreferenced local variable
/wd4005 # Macro redefinition
/wd4065 # switch statement contains 'default' but no 'case' labels
/wd4267 # Conversion from 'size_t' to a smaller type, possible loss of data
/wd4244 # Conversion from one type to another type, possible loss of ata
/wd4805 # Unsafe mix of type
/wd4305 # Truncation from 'type1' to 'type2' (often double to float)
/wd4996 # Function or variable may be unsafe/deprecated
)
function(disable_msvc_warnings target_name)
if(TARGET ${target_name})
target_compile_options(${target_name} PRIVATE ${MSVC_WARNING_FLAGS})
endif()
endfunction()
if (WHISPER_BUILD_EXAMPLES)
disable_msvc_warnings(whisper)
disable_msvc_warnings(common)
disable_msvc_warnings(common-sdl)
disable_msvc_warnings(lsp)
disable_msvc_warnings(wchess-core)
disable_msvc_warnings(whisper-command)
disable_msvc_warnings(whisper-cli)
disable_msvc_warnings(whisper-server)
disable_msvc_warnings(whisper-stream)
disable_msvc_warnings(whisper-talk-llama)
disable_msvc_warnings(whisper-bench)
disable_msvc_warnings(quantize)
endif()
endif()

View File

@ -4,7 +4,7 @@
.PHONY: build
build:
cmake -B build $(CMAKE_ARGS)
cmake -B build
cmake --build build --config Release
# download a few audio samples into folder "./samples":
@ -41,17 +41,17 @@ samples:
tiny.en tiny base.en base small.en small medium.en medium large-v1 large-v2 large-v3 large-v3-turbo:
bash ./models/download-ggml-model.sh $@
cmake -B build $(CMAKE_ARGS)
cmake -B build
cmake --build build --config Release
@echo ""
@echo "==============================================="
@echo "Running $@ on all samples in ./samples ..."
@echo "==============================================="
@echo ""
@for f in samples/*.{flac,mp3,ogg,wav}; do \
@for f in samples/*$(.flac .mp3 .ogg .wav); do \
echo "----------------------------------------------" ; \
echo "[+] Running $@ on $$f ... (run 'ffplay $$f' to listen)" ; \
echo "----------------------------------------------" ; \
echo "----------------------------------------------" ; \
echo "" ; \
./build/bin/whisper-cli -m models/ggml-$@.bin -f $$f ; \
echo "" ; \

161
README.md
View File

@ -2,12 +2,15 @@
![whisper.cpp](https://user-images.githubusercontent.com/1991296/235238348-05d0f6a4-da44-4900-a1de-d0707e75b763.jpeg)
[![Actions Status](https://github.com/ggml-org/whisper.cpp/workflows/CI/badge.svg)](https://github.com/ggml-org/whisper.cpp/actions)
[![Actions Status](https://github.com/ggerganov/whisper.cpp/workflows/CI/badge.svg)](https://github.com/ggerganov/whisper.cpp/actions)
[![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT)
[![Conan Center](https://shields.io/conan/v/whisper-cpp)](https://conan.io/center/whisper-cpp)
[![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/)
Stable: [v1.7.5](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.7.5) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/)
> [!NOTE]
> New maintenance roadmap: https://github.com/ggerganov/whisper.cpp/discussions/2788
Stable: [v1.7.4](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.7.4) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
@ -23,8 +26,7 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp
- [Efficient GPU support for NVIDIA](#nvidia-gpu-support)
- [OpenVINO Support](#openvino-support)
- [Ascend NPU Support](#ascend-npu-support)
- [Moore Threads GPU Support](#moore-threads-gpu-support)
- [C-style API](https://github.com/ggml-org/whisper.cpp/blob/master/include/whisper.h)
- [C-style API](https://github.com/ggerganov/whisper.cpp/blob/master/include/whisper.h)
Supported platforms:
@ -32,14 +34,14 @@ Supported platforms:
- [x] [iOS](examples/whisper.objc)
- [x] [Android](examples/whisper.android)
- [x] [Java](bindings/java/README.md)
- [x] Linux / [FreeBSD](https://github.com/ggml-org/whisper.cpp/issues/56#issuecomment-1350920264)
- [x] Linux / [FreeBSD](https://github.com/ggerganov/whisper.cpp/issues/56#issuecomment-1350920264)
- [x] [WebAssembly](examples/whisper.wasm)
- [x] Windows ([MSVC](https://github.com/ggml-org/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggml-org/whisper.cpp/issues/168)]
- [x] [Raspberry Pi](https://github.com/ggml-org/whisper.cpp/discussions/166)
- [x] [Docker](https://github.com/ggml-org/whisper.cpp/pkgs/container/whisper.cpp)
- [x] Windows ([MSVC](https://github.com/ggerganov/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggerganov/whisper.cpp/issues/168)]
- [x] [Raspberry Pi](https://github.com/ggerganov/whisper.cpp/discussions/166)
- [x] [Docker](https://github.com/ggerganov/whisper.cpp/pkgs/container/whisper.cpp)
The entire high-level implementation of the model is contained in [whisper.h](include/whisper.h) and [whisper.cpp](src/whisper.cpp).
The rest of the code is part of the [`ggml`](https://github.com/ggml-org/ggml) machine learning library.
The rest of the code is part of the [`ggml`](https://github.com/ggerganov/ggml) machine learning library.
Having such a lightweight implementation of the model allows to easily integrate it in different platforms and applications.
As an example, here is a video of running the model on an iPhone 13 device - fully offline, on-device: [whisper.objc](examples/whisper.objc)
@ -52,14 +54,14 @@ https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a
On Apple Silicon, the inference runs fully on the GPU via Metal:
https://github.com/ggml-org/whisper.cpp/assets/1991296/c82e8f86-60dc-49f2-b048-d2fdbd6b5225
https://github.com/ggerganov/whisper.cpp/assets/1991296/c82e8f86-60dc-49f2-b048-d2fdbd6b5225
## Quick start
First clone the repository:
```bash
git clone https://github.com/ggml-org/whisper.cpp.git
git clone https://github.com/ggerganov/whisper.cpp.git
```
Navigate into the directory:
@ -150,7 +152,6 @@ standard cmake setup with:
cmake -B build -DGGML_BLAS=1
cmake --build build --config Release
./build/bin/whisper-cli [ .. etc .. ]
```
## Quantization
@ -183,11 +184,11 @@ speed-up - more than x3 faster compared with CPU-only execution. Here are the in
```
- To ensure `coremltools` operates correctly, please confirm that [Xcode](https://developer.apple.com/xcode/) is installed and execute `xcode-select --install` to install the command-line tools.
- Python 3.11 is recommended.
- Python 3.10 is recommended.
- MacOS Sonoma (version 14) or newer is recommended, as older versions of MacOS might experience issues with transcription hallucination.
- [OPTIONAL] It is recommended to utilize a Python version management system, such as [Miniconda](https://docs.conda.io/en/latest/miniconda.html) for this step:
- To create an environment, use: `conda create -n py311-whisper python=3.11 -y`
- To activate the environment, use: `conda activate py311-whisper`
- To create an environment, use: `conda create -n py310-whisper python=3.10 -y`
- To activate the environment, use: `conda activate py310-whisper`
- Generate a Core ML model. For example, to generate a `base.en` model, use:
@ -224,7 +225,7 @@ speed-up - more than x3 faster compared with CPU-only execution. Here are the in
The first run on a device is slow, since the ANE service compiles the Core ML model to some device-specific format.
Next runs are faster.
For more information about the Core ML implementation please refer to PR [#566](https://github.com/ggml-org/whisper.cpp/pull/566).
For more information about the Core ML implementation please refer to PR [#566](https://github.com/ggerganov/whisper.cpp/pull/566).
## OpenVINO support
@ -309,7 +310,7 @@ This can result in significant speedup in encoder performance. Here are the inst
The first time run on an OpenVINO device is slow, since the OpenVINO framework will compile the IR (Intermediate Representation) model to a device-specific 'blob'. This device-specific blob will get
cached for the next run.
For more information about the OpenVINO implementation please refer to PR [#1037](https://github.com/ggml-org/whisper.cpp/pull/1037).
For more information about the OpenVINO implementation please refer to PR [#1037](https://github.com/ggerganov/whisper.cpp/pull/1037).
## NVIDIA GPU support
@ -323,12 +324,6 @@ cmake -B build -DGGML_CUDA=1
cmake --build build -j --config Release
```
or for newer NVIDIA GPU's (RTX 5000 series):
```
cmake -B build -DGGML_CUDA=1 -DCMAKE_CUDA_ARCHITECTURES="86"
cmake --build build -j --config Release
```
## Vulkan GPU support
Cross-vendor solution which allows you to accelerate workload on your GPU.
First, make sure your graphics card driver provides support for Vulkan API.
@ -382,56 +377,6 @@ Run the inference examples as usual, for example:
- If you have trouble with Ascend NPU device, please create a issue with **[CANN]** prefix/tag.
- If you run successfully with your Ascend NPU device, please help update the table `Verified devices`.
## Moore Threads GPU support
With Moore Threads cards the processing of the models is done efficiently on the GPU via muBLAS and custom MUSA kernels.
First, make sure you have installed `MUSA SDK rc3.1.1`: https://developer.mthreads.com/sdk/download/musa?equipment=&os=&driverVersion=&version=rc3.1.1
Now build `whisper.cpp` with MUSA support:
```
cmake -B build -DGGML_MUSA=1
cmake --build build -j --config Release
```
or specify the architecture for your Moore Threads GPU. For example, if you have a MTT S80 GPU, you can specify the architecture as follows:
```
cmake -B build -DGGML_MUSA=1 -DMUSA_ARCHITECTURES="21"
cmake --build build -j --config Release
```
## FFmpeg support (Linux only)
If you want to support more audio formats (such as Opus and AAC), you can turn on the `WHISPER_FFMPEG` build flag to enable FFmpeg integration.
First, you need to install required libraries:
```bash
# Debian/Ubuntu
sudo apt install libavcodec-dev libavformat-dev libavutil-dev
# RHEL/Fedora
sudo dnf install libavcodec-free-devel libavformat-free-devel libavutil-free-devel
```
Then you can build the project as follows:
```bash
cmake -B build -D WHISPER_FFMPEG=yes
cmake --build build
```
Run the following example to confirm it's working:
```bash
# Convert an audio file to Opus format
ffmpeg -i samples/jfk.wav jfk.opus
# Transcribe the audio file
./build/bin/whisper-cli --model models/ggml-base.en.bin --file jfk.opus
```
## Docker
### Prerequisites
@ -443,9 +388,8 @@ ffmpeg -i samples/jfk.wav jfk.opus
We have two Docker images available for this project:
1. `ghcr.io/ggml-org/whisper.cpp:main`: This image includes the main executable file as well as `curl` and `ffmpeg`. (platforms: `linux/amd64`, `linux/arm64`)
2. `ghcr.io/ggml-org/whisper.cpp:main-cuda`: Same as `main` but compiled with CUDA support. (platforms: `linux/amd64`)
3. `ghcr.io/ggml-org/whisper.cpp:main-musa`: Same as `main` but compiled with MUSA support. (platforms: `linux/amd64`)
1. `ghcr.io/ggerganov/whisper.cpp:main`: This image includes the main executable file as well as `curl` and `ffmpeg`. (platforms: `linux/amd64`, `linux/arm64`)
2. `ghcr.io/ggerganov/whisper.cpp:main-cuda`: Same as `main` but compiled with CUDA support. (platforms: `linux/amd64`)
### Usage
@ -458,11 +402,11 @@ docker run -it --rm \
docker run -it --rm \
-v path/to/models:/models \
-v path/to/audios:/audios \
whisper.cpp:main "whisper-cli -m /models/ggml-base.bin -f /audios/jfk.wav"
whisper.cpp:main "./main -m /models/ggml-base.bin -f /audios/jfk.wav"
# transcribe an audio file in samples folder
docker run -it --rm \
-v path/to/models:/models \
whisper.cpp:main "whisper-cli -m /models/ggml-base.bin -f ./samples/jfk.wav"
whisper.cpp:main "./main -m /models/ggml-base.bin -f ./samples/jfk.wav"
```
## Installing with Conan
@ -483,8 +427,7 @@ For detailed instructions on how to use Conan, please refer to the [Conan docume
This is a naive example of performing real-time inference on audio from your microphone.
The [stream](examples/stream) tool samples the audio every half a second and runs the transcription continuously.
More info is available in [issue #10](https://github.com/ggml-org/whisper.cpp/issues/10).
You will need to have [sdl2](https://wiki.libsdl.org/SDL2/Installation) installed for it to work properly.
More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10).
```bash
cmake -B build -DWHISPER_SDL2=ON
@ -572,7 +515,7 @@ main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 pr
## Speaker segmentation via tinydiarize (experimental)
More information about this approach is available here: https://github.com/ggml-org/whisper.cpp/pull/1058
More information about this approach is available here: https://github.com/ggerganov/whisper.cpp/pull/1058
Sample usage:
@ -636,7 +579,7 @@ https://user-images.githubusercontent.com/1991296/199337538-b7b0c7a3-2753-4a88-a
## Video comparison of different models
Use the [scripts/bench-wts.sh](https://github.com/ggml-org/whisper.cpp/blob/master/scripts/bench-wts.sh) script to generate a video in the following format:
Use the [scripts/bench-wts.sh](https://github.com/ggerganov/whisper.cpp/blob/master/scripts/bench-wts.sh) script to generate a video in the following format:
```bash
./scripts/bench-wts.sh samples/jfk.wav
@ -653,7 +596,7 @@ In order to have an objective comparison of the performance of the inference acr
use the [whisper-bench](examples/bench) tool. The tool simply runs the Encoder part of the model and prints how much time it
took to execute it. The results are summarized in the following Github issue:
[Benchmark results](https://github.com/ggml-org/whisper.cpp/issues/89)
[Benchmark results](https://github.com/ggerganov/whisper.cpp/issues/89)
Additionally a script to run whisper.cpp with different models and audio files is provided [bench.py](scripts/bench.py).
@ -680,24 +623,25 @@ You can download the converted models using the [models/download-ggml-model.sh](
or manually from here:
- https://huggingface.co/ggerganov/whisper.cpp
- https://ggml.ggerganov.com
For more details, see the conversion script [models/convert-pt-to-ggml.py](models/convert-pt-to-ggml.py) or [models/README.md](models/README.md).
## [Bindings](https://github.com/ggml-org/whisper.cpp/discussions/categories/bindings)
## [Bindings](https://github.com/ggerganov/whisper.cpp/discussions/categories/bindings)
- [x] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggml-org/whisper.cpp/discussions/310)
- [x] JavaScript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggml-org/whisper.cpp/discussions/309)
- [x] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggerganov/whisper.cpp/discussions/310)
- [x] JavaScript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309)
- React Native (iOS / Android): [whisper.rn](https://github.com/mybigday/whisper.rn)
- [x] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggml-org/whisper.cpp/discussions/312)
- [x] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312)
- [x] Java:
- [GiviMAD/whisper-jni](https://github.com/GiviMAD/whisper-jni)
- [x] Ruby: [bindings/ruby](bindings/ruby) | [#507](https://github.com/ggml-org/whisper.cpp/discussions/507)
- [x] Objective-C / Swift: [ggml-org/whisper.spm](https://github.com/ggml-org/whisper.spm) | [#313](https://github.com/ggml-org/whisper.cpp/discussions/313)
- [x] Ruby: [bindings/ruby](bindings/ruby) | [#507](https://github.com/ggerganov/whisper.cpp/discussions/507)
- [x] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm) | [#313](https://github.com/ggerganov/whisper.cpp/discussions/313)
- [exPHAT/SwiftWhisper](https://github.com/exPHAT/SwiftWhisper)
- [x] .NET: | [#422](https://github.com/ggml-org/whisper.cpp/discussions/422)
- [x] .NET: | [#422](https://github.com/ggerganov/whisper.cpp/discussions/422)
- [sandrohanea/whisper.net](https://github.com/sandrohanea/whisper.net)
- [NickDarvey/whisper](https://github.com/NickDarvey/whisper)
- [x] Python: | [#9](https://github.com/ggml-org/whisper.cpp/issues/9)
- [x] Python: | [#9](https://github.com/ggerganov/whisper.cpp/issues/9)
- [stlukey/whispercpp.py](https://github.com/stlukey/whispercpp.py) (Cython)
- [AIWintermuteAI/whispercpp](https://github.com/AIWintermuteAI/whispercpp) (Updated fork of aarnphm/whispercpp)
- [aarnphm/whispercpp](https://github.com/aarnphm/whispercpp) (Pybind11)
@ -705,33 +649,6 @@ For more details, see the conversion script [models/convert-pt-to-ggml.py](model
- [x] R: [bnosac/audio.whisper](https://github.com/bnosac/audio.whisper)
- [x] Unity: [macoron/whisper.unity](https://github.com/Macoron/whisper.unity)
## XCFramework
The XCFramework is a precompiled version of the library for iOS, visionOS, tvOS,
and macOS. It can be used in Swift projects without the need to compile the
library from source. For examples:
```swift
// swift-tools-version: 5.10
// The swift-tools-version declares the minimum version of Swift required to build this package.
import PackageDescription
let package = Package(
name: "Whisper",
targets: [
.executableTarget(
name: "Whisper",
dependencies: [
"WhisperFramework"
]),
.binaryTarget(
name: "WhisperFramework",
url: "https://github.com/ggml-org/whisper.cpp/releases/download/v1.7.5/whisper-v1.7.5-xcframework.zip",
checksum: "c7faeb328620d6012e130f3d705c51a6ea6c995605f2df50f6e1ad68c59c6c4a"
)
]
)
```
## Examples
There are various examples of using the library for different projects in the [examples](examples) folder.
@ -750,13 +667,13 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
| [whisper.android](examples/whisper.android) | | Android mobile application using whisper.cpp |
| [whisper.nvim](examples/whisper.nvim) | | Speech-to-text plugin for Neovim |
| [generate-karaoke.sh](examples/generate-karaoke.sh) | | Helper script to easily [generate a karaoke video](https://youtu.be/uj7hVta4blM) of raw audio capture |
| [livestream.sh](examples/livestream.sh) | | [Livestream audio transcription](https://github.com/ggml-org/whisper.cpp/issues/185) |
| [livestream.sh](examples/livestream.sh) | | [Livestream audio transcription](https://github.com/ggerganov/whisper.cpp/issues/185) |
| [yt-wsp.sh](examples/yt-wsp.sh) | | Download + transcribe and/or translate any VOD [(original)](https://gist.github.com/DaniruKun/96f763ec1a037cc92fe1a059b643b818) |
| [wchess](examples/wchess) | [wchess.wasm](examples/wchess) | Voice-controlled chess |
## [Discussions](https://github.com/ggml-org/whisper.cpp/discussions)
## [Discussions](https://github.com/ggerganov/whisper.cpp/discussions)
If you have any kind of feedback about this project feel free to use the Discussions section and open a new topic.
You can use the [Show and tell](https://github.com/ggml-org/whisper.cpp/discussions/categories/show-and-tell) category
You can use the [Show and tell](https://github.com/ggerganov/whisper.cpp/discussions/categories/show-and-tell) category
to share your own projects that use `whisper.cpp`. If you have a question, make sure to check the
[Frequently asked questions (#126)](https://github.com/ggml-org/whisper.cpp/discussions/126) discussion.
[Frequently asked questions (#126)](https://github.com/ggerganov/whisper.cpp/discussions/126) discussion.

View File

@ -11,11 +11,11 @@ UNAME_M := $(shell uname -m)
endif
GGML_METAL_PATH_RESOURCES := $(abspath ../..)
BUILD_DIR := build_go
BUILD_DIR := build
MODELS_DIR := models
EXAMPLES_DIR := $(wildcard examples/*)
INCLUDE_PATH := $(abspath ../../include):$(abspath ../../ggml/include)
LIBRARY_PATH := $(abspath ../../${BUILD_DIR}/src:$(abspath ../../${BUILD_DIR}/ggml/src))
LIBRARY_PATH := $(abspath ../..)
ifeq ($(GGML_CUDA),1)
LIBRARY_PATH := $(LIBRARY_PATH):$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib/
@ -29,10 +29,8 @@ endif
all: clean whisper examples
whisper: mkdir
cmake -S ../.. -B ../../${BUILD_DIR} \
-DCMAKE_BUILD_TYPE=Release \
-DBUILD_SHARED_LIBS=OFF
cmake --build ../../${BUILD_DIR} --target whisper
@echo Build whisper
@${MAKE} -C ../.. libwhisper.a
test: model-small whisper modtidy
ifeq ($(UNAME_S),Darwin)

View File

@ -31,7 +31,7 @@ func main() {
if err != nil {
panic(err)
}
if err := context.Process(samples, nil, nil, nil); err != nil {
if err := context.Process(samples, nil, nil); err != nil {
return err
}
@ -51,7 +51,7 @@ func main() {
In order to build, you need to have the Go compiler installed. You can get it from [here](https://golang.org/dl/). Run the tests with:
```bash
git clone https://github.com/ggml-org/whisper.cpp.git
git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp/bindings/go
make test
```
@ -98,7 +98,7 @@ The API Documentation:
Getting help:
* Follow the discussion for the go bindings [here](https://github.com/ggml-org/whisper.cpp/discussions/312)
* Follow the discussion for the go bindings [here](https://github.com/ggerganov/whisper.cpp/discussions/312)
## License

View File

@ -1,5 +1,5 @@
/*
github.com/ggml-org/whisper.cpp/bindings/go
github.com/ggerganov/whisper.cpp/bindings/go
provides a speech-to-text service bindings for the Go programming language.
*/
package whisper

View File

@ -67,7 +67,7 @@ func Process(model whisper.Model, path string, flags *Flags) error {
// Process the data
fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
context.ResetTimings()
if err := context.Process(data, nil, cb, nil); err != nil {
if err := context.Process(data, cb, nil); err != nil {
return err
}

View File

@ -71,10 +71,6 @@ func (context *context) Language() string {
return whisper.Whisper_lang_str(context.params.Language())
}
func (context *context) DetectedLanguage() string {
return whisper.Whisper_lang_str(context.model.ctx.Whisper_full_lang_id())
}
// Set translate flag
func (context *context) SetTranslate(v bool) {
context.params.SetTranslate(v)
@ -193,7 +189,6 @@ func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]f
// Process new sample data and return any errors
func (context *context) Process(
data []float32,
callEncoderBegin EncoderBeginCallback,
callNewSegment SegmentCallback,
callProgress ProgressCallback,
) error {
@ -208,20 +203,7 @@ func (context *context) Process(
// We don't do parallel processing at the moment
processors := 0
if processors > 1 {
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, callEncoderBegin,
func(new int) {
if callNewSegment != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
callNewSegment(toSegment(context.model.ctx, i))
}
}
}); err != nil {
return err
}
} else if err := context.model.ctx.Whisper_full(context.params, data, callEncoderBegin,
func(new int) {
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
if callNewSegment != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
@ -229,11 +211,22 @@ func (context *context) Process(
callNewSegment(toSegment(context.model.ctx, i))
}
}
}, func(progress int) {
if callProgress != nil {
callProgress(progress)
}
}); err != nil {
return err
}
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
if callNewSegment != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
callNewSegment(toSegment(context.model.ctx, i))
}
}
}, func(progress int) {
if callProgress != nil {
callProgress(progress)
}
}); err != nil {
return err
}

View File

@ -88,37 +88,6 @@ func TestProcess(t *testing.T) {
context, err := model.NewContext()
assert.NoError(err)
err = context.Process(data, nil, nil, nil)
err = context.Process(data, nil, nil)
assert.NoError(err)
}
func TestDetectedLanguage(t *testing.T) {
assert := assert.New(t)
fh, err := os.Open(SamplePath)
assert.NoError(err)
defer fh.Close()
// Decode the WAV file - load the full buffer
dec := wav.NewDecoder(fh)
buf, err := dec.FullPCMBuffer()
assert.NoError(err)
assert.Equal(uint16(1), dec.NumChans)
data := buf.AsFloat32Buffer().Data
model, err := whisper.New(ModelPath)
assert.NoError(err)
assert.NotNil(model)
defer model.Close()
context, err := model.NewContext()
assert.NoError(err)
err = context.Process(data, nil, nil, nil)
assert.NoError(err)
expectedLanguage := "en"
actualLanguage := context.DetectedLanguage()
assert.Equal(expectedLanguage, actualLanguage)
}

View File

@ -16,10 +16,6 @@ type SegmentCallback func(Segment)
// processing. It is called during the Process function
type ProgressCallback func(int)
// EncoderBeginCallback is the callback function for checking if we want to
// continue processing. It is called during the Process function
type EncoderBeginCallback func() bool
// Model is the interface to a whisper model. Create a new model with the
// function whisper.New(string)
type Model interface {
@ -35,13 +31,12 @@ type Model interface {
Languages() []string
}
// Context is the speech recognition context.
// Context is the speach recognition context.
type Context interface {
SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language.
SetTranslate(bool) // Set translate flag
IsMultilingual() bool // Return true if the model is multilingual.
Language() string // Get language
DetectedLanguage() string // Get detected language
SetOffset(time.Duration) // Set offset
SetDuration(time.Duration) // Set duration
@ -63,7 +58,7 @@ type Context interface {
// Process mono audio data and return any errors.
// If defined, newly generated segments are passed to the
// callback function during processing.
Process([]float32, EncoderBeginCallback, SegmentCallback, ProgressCallback) error
Process([]float32, SegmentCallback, ProgressCallback) error
// After process is called, return segments until the end of the stream
// is reached, when io.EOF is returned.

View File

@ -9,7 +9,7 @@ import (
// CGO
/*
#cgo LDFLAGS: -lwhisper -lggml -lggml-base -lggml-cpu -lm -lstdc++ -fopenmp
#cgo LDFLAGS: -lwhisper -lm -lstdc++ -fopenmp
#cgo darwin LDFLAGS: -framework Accelerate -framework Metal -framework Foundation -framework CoreGraphics
#include <whisper.h>
#include <stdlib.h>

View File

@ -31,10 +31,10 @@ public class Example {
var whisperParams = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
// custom configuration if required
whisperParams.temperature_inc = 0f;
var samples = readAudio(); // divide each value by 32767.0f
whisper.fullTranscribe(whisperParams, samples);
int segmentCount = whisper.getTextSegmentCount(context);
for (int i = 0; i < segmentCount; i++) {
String text = whisper.getTextSegment(context, i);
@ -52,7 +52,7 @@ public class Example {
In order to build, you need to have the JDK 8 or higher installed. Run the tests with:
```bash
git clone https://github.com/ggml-org/whisper.cpp.git
git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp/bindings/java
./gradlew build

View File

@ -25,43 +25,25 @@ sourceSets {
}
tasks.register('copyLibwhisperDynlib', Copy) {
from '../../build/src'
include 'libwhisper.dylib'
into 'build/generated/resources/main'
from '../../build'
include 'libwhisper.dynlib'
into 'build/generated/resources/main/darwin'
}
tasks.register('copyLibwhisperSo', Copy) {
from '../../build/src'
from '../../build'
include 'libwhisper.so'
into 'build/generated/resources/main'
into 'build/generated/resources/main/linux-x86-64'
}
tasks.register('copyWhisperDLL', Copy) {
from '../../build/bin/Release'
tasks.register('copyWhisperDll', Copy) {
from '../../build/Release'
include 'whisper.dll'
into 'build/generated/resources/main'
}
tasks.register('copyGGML_BASE_DLL', Copy) {
from '../../build/bin/Release'
include 'ggml-base.dll'
into 'build/generated/resources/main'
}
tasks.register('copyGGML_DLL', Copy) {
from '../../build/bin/Release'
include 'ggml.dll'
into 'build/generated/resources/main'
}
tasks.register('copyGGML_CPU_DLL', Copy) {
from '../../build/bin/Release'
include 'ggml-cpu.dll'
into 'build/generated/resources/main'
into 'build/generated/resources/main/windows-x86-64'
}
tasks.register('copyLibs') {
dependsOn copyLibwhisperDynlib, copyLibwhisperSo, copyWhisperDLL, copyGGML_BASE_DLL, copyGGML_DLL, copyGGML_CPU_DLL
dependsOn copyLibwhisperDynlib, copyLibwhisperSo, copyWhisperDll
}
test {
@ -73,12 +55,7 @@ java {
withJavadocJar()
}
sourcesJar() {
dependsOn copyLibs
}
jar {
dependsOn copyLibs
exclude '**/whisper_java.exp', '**/whisper_java.lib'
}
@ -90,9 +67,6 @@ tasks.withType(Test) {
useJUnitPlatform()
}
test.dependsOn copyLibs
processResources.dependsOn copyLibs
dependencies {
implementation "net.java.dev.jna:jna:5.13.0"
testImplementation "org.junit.jupiter:junit-jupiter:5.9.2"

0
bindings/java/gradlew vendored Executable file → Normal file
View File

View File

@ -1,24 +0,0 @@
package io.github.ggerganov.whispercpp;
/**
* Presets for alignment heads in DTW token timestamps
*/
public class WhisperConstants {
// Alignment heads presets
public static final int WHISPER_AHEADS_NONE = 0;
public static final int WHISPER_AHEADS_TINY_EN = 1;
public static final int WHISPER_AHEADS_TINY = 2;
public static final int WHISPER_AHEADS_BASE_EN = 3;
public static final int WHISPER_AHEADS_BASE = 4;
public static final int WHISPER_AHEADS_SMALL_EN = 5;
public static final int WHISPER_AHEADS_SMALL = 6;
public static final int WHISPER_AHEADS_MEDIUM_EN = 7;
public static final int WHISPER_AHEADS_MEDIUM = 8;
public static final int WHISPER_AHEADS_LARGE_V1 = 9;
public static final int WHISPER_AHEADS_LARGE_V2 = 10;
public static final int WHISPER_AHEADS_LARGE_V3 = 11;
public static final int WHISPER_AHEADS_LARGE_V3_TURBO = 12;
public static final int WHISPER_AHEADS_CUSTOM = 13;
public static final int WHISPER_AHEADS_N_TOP_MOST = 14;
public static final int WHISPER_AHEADS_COUNT = 15;
}

View File

@ -1,9 +1,7 @@
package io.github.ggerganov.whispercpp;
import com.sun.jna.NativeLong;
import com.sun.jna.Structure;
import com.sun.jna.ptr.PointerByReference;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.ggml.GgmlType;
import io.github.ggerganov.whispercpp.WhisperModel;
import io.github.ggerganov.whispercpp.params.WhisperContextParams;
@ -11,26 +9,33 @@ import io.github.ggerganov.whispercpp.params.WhisperContextParams;
import java.util.List;
public class WhisperContext extends Structure {
public NativeLong t_load_us;
public NativeLong t_start_us;
int t_load_us = 0;
int t_start_us = 0;
/** weight type (FP32 / FP16 / QX) */
public GgmlType wtype = GgmlType.GGML_TYPE_F16;
GgmlType wtype = GgmlType.GGML_TYPE_F16;
/** intermediate type (FP32 or FP16) */
public GgmlType itype = GgmlType.GGML_TYPE_F16;
GgmlType itype = GgmlType.GGML_TYPE_F16;
public WhisperContextParams.ByValue params;
public Pointer model;
public Pointer vocab;
public Pointer state;
// WhisperModel model;
public PointerByReference model;
// whisper_vocab vocab;
// whisper_state * state = nullptr;
public PointerByReference vocab;
public PointerByReference state;
/** populated by whisper_init_from_file_with_params() */
public Pointer path_model;
String path_model;
WhisperContextParams params;
@Override
protected List<String> getFieldOrder() {
return List.of("t_load_us", "t_start_us", "wtype", "itype",
"params", "model", "vocab", "state", "path_model");
}
// public static class ByReference extends WhisperContext implements Structure.ByReference {
// }
//
// public static class ByValue extends WhisperContext implements Structure.ByValue {
// }
//
// @Override
// protected List<String> getFieldOrder() {
// return List.of("t_load_us", "t_start_us", "wtype", "itype", "model", "vocab", "state", "path_model");
// }
}

View File

@ -43,11 +43,11 @@ public class WhisperCpp implements AutoCloseable {
* @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")
* @param params - params to use when initialising the context
*/
public void initContext(String modelPath, WhisperContextParams.ByValue params) throws FileNotFoundException {
public void initContext(String modelPath, WhisperContextParams params) throws FileNotFoundException {
initContextImpl(modelPath, params);
}
private void initContextImpl(String modelPath, WhisperContextParams.ByValue params) throws FileNotFoundException {
private void initContextImpl(String modelPath, WhisperContextParams params) throws FileNotFoundException {
if (ctx != null) {
lib.whisper_free(ctx);
}
@ -69,13 +69,15 @@ public class WhisperCpp implements AutoCloseable {
/**
* Provides default params which can be used with `whisper_init_from_file_with_params()` etc.
* Returns a ByValue instance to ensure proper parameter passing to native code.
* Because this function allocates memory for the params, the caller must call either:
* - call `whisper_free_context_params()`
* - `Native.free(Pointer.nativeValue(pointer));`
*/
public WhisperContextParams.ByValue getContextDefaultParams() {
WhisperContextParams.ByValue valueParams = new WhisperContextParams.ByValue(
lib.whisper_context_default_params_by_ref());
valueParams.read();
return valueParams;
public WhisperContextParams getContextDefaultParams() {
paramsPointer = lib.whisper_context_default_params_by_ref();
WhisperContextParams params = new WhisperContextParams(paramsPointer);
params.read();
return params;
}
/**
@ -86,7 +88,7 @@ public class WhisperCpp implements AutoCloseable {
*
* @param strategy - GREEDY
*/
public WhisperFullParams.ByValue getFullDefaultParams(WhisperSamplingStrategy strategy) {
public WhisperFullParams getFullDefaultParams(WhisperSamplingStrategy strategy) {
Pointer pointer;
// whisper_full_default_params_by_ref allocates memory which we need to delete, so only create max 1 pointer for each strategy.
@ -102,7 +104,7 @@ public class WhisperCpp implements AutoCloseable {
pointer = beamParamsPointer;
}
WhisperFullParams.ByValue params = new WhisperFullParams.ByValue(pointer);
WhisperFullParams params = new WhisperFullParams(pointer);
params.read();
return params;
}
@ -136,21 +138,15 @@ public class WhisperCpp implements AutoCloseable {
}
/**
* Run the entire model: PCM -&gt; log mel spectrogram -&gt; encoder -&gt; decoder -&gt; text.
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.
* Not thread safe for same context
* Uses the specified decoding strategy to obtain the text.
*/
public String fullTranscribe(WhisperFullParams.ByValue whisperParams, float[] audioData) throws IOException {
public String fullTranscribe(WhisperFullParams whisperParams, float[] audioData) throws IOException {
if (ctx == null) {
throw new IllegalStateException("Model not initialised");
}
/*
WhisperFullParams.ByValue valueParams = new WhisperFullParams.ByValue(
lib.whisper_full_default_params_by_ref(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH.ordinal()));
valueParams.read();
*/
if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) {
throw new IOException("Failed to process audio");
}
@ -167,17 +163,12 @@ public class WhisperCpp implements AutoCloseable {
return str.toString().trim();
}
public List<WhisperSegment> fullTranscribeWithTime(WhisperFullParams whisperParams, float[] audioData) throws IOException {
if (ctx == null) {
throw new IllegalStateException("Model not initialised");
}
WhisperFullParams.ByValue valueParams = new WhisperFullParams.ByValue(
lib.whisper_full_default_params_by_ref(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH.ordinal()));
valueParams.read();
if (lib.whisper_full(ctx, valueParams, audioData, audioData.length) != 0) {
if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) {
throw new IOException("Failed to process audio");
}

View File

@ -9,7 +9,6 @@ import io.github.ggerganov.whispercpp.params.WhisperContextParams;
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
public interface WhisperCppJnaLibrary extends Library {
WhisperCppJnaLibrary instance = Native.load("whisper", WhisperCppJnaLibrary.class);
String whisper_print_system_info();
@ -39,7 +38,7 @@ public interface WhisperCppJnaLibrary extends Library {
* @param params Pointer to whisper_context_params
* @return Whisper context on success, null on failure
*/
Pointer whisper_init_from_file_with_params(String path_model, WhisperContextParams.ByValue params);
Pointer whisper_init_from_file_with_params(String path_model, WhisperContextParams params);
/**
* Allocate (almost) all memory needed for the model by loading from a buffer.
@ -181,12 +180,12 @@ public interface WhisperCppJnaLibrary extends Library {
/**
* @return the id of the specified language, returns -1 if not found.
* Examples:
* "de" -&gt; 2
* "german" -&gt; 2
* "de" -> 2
* "german" -> 2
*/
int whisper_lang_id(String lang);
/** @return the short string of the specified language id (e.g. 2 -&gt; "de"), returns nullptr if not found */
/** @return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found */
String whisper_lang_str(int id);
/**
@ -269,21 +268,20 @@ public interface WhisperCppJnaLibrary extends Library {
void whisper_free_params(Pointer params);
/**
* Run the entire model: PCM -&gt; log mel spectrogram -&gt; encoder -&gt; decoder -&gt; text
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
* Not thread safe for same context
* Uses the specified decoding strategy to obtain the text.
*/
int whisper_full(Pointer ctx, WhisperFullParams.ByValue params, final float[] samples, int n_samples);
int whisper_full(Pointer ctx, WhisperFullParams params, final float[] samples, int n_samples);
public int whisper_full_with_state(Pointer ctx, Pointer state, WhisperFullParams.ByValue params, float[] samples, int n_samples);
//int whisper_full_with_state(Pointer ctx, Pointer state, WhisperFullParams params, final float[] samples, int n_samples);
int whisper_full_with_state(Pointer ctx, Pointer state, WhisperFullParams params, final float[] samples, int n_samples);
// Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
// Result is stored in the default state of the context
// Not thread safe if executed in parallel on the same context.
// It seems this approach can offer some speedup in some cases.
// However, the transcription accuracy can be worse at the beginning and end of each chunk.
int whisper_full_parallel(Pointer ctx, WhisperFullParams.ByValue params, final float[] samples, int n_samples, int n_processors);
int whisper_full_parallel(Pointer ctx, WhisperFullParams params, final float[] samples, int n_samples, int n_processors);
/**
* Number of generated text segments.

View File

@ -1,17 +0,0 @@
package io.github.ggerganov.whispercpp.callbacks;
import com.sun.jna.Callback;
/**
* Callback for aborting GGML computation
* Maps to the C typedef: bool (*ggml_abort_callback)(void * data)
*/
public interface GgmlAbortCallback extends Callback {
/**
* Return true to abort the computation, false to continue
*
* @param data User data passed to the callback
* @return true to abort, false to continue
*/
boolean invoke(com.sun.jna.Pointer data);
}

View File

@ -1,30 +0,0 @@
package io.github.ggerganov.whispercpp.params;
import com.sun.jna.*;
import java.util.Arrays;
import java.util.List;
public class WhisperAhead extends Structure {
public int n_text_layer;
public int n_head;
public WhisperAhead() {
super();
}
public WhisperAhead(int textLayer, int head) {
super();
this.n_text_layer = textLayer;
this.n_head = head;
}
@Override
protected List<String> getFieldOrder() {
return Arrays.asList("n_text_layer", "n_head");
}
public static class ByReference extends WhisperAhead implements Structure.ByReference {}
public static class ByValue extends WhisperAhead implements Structure.ByValue {}
}

View File

@ -1,41 +0,0 @@
package io.github.ggerganov.whispercpp.params;
import com.sun.jna.*;
import java.util.Arrays;
import java.util.List;
public class WhisperAheads extends Structure {
public NativeLong n_heads;
public Pointer heads;
public WhisperAheads() {
super();
}
/**
* Create alignment heads from an array of WhisperAhead objects
*/
public void setHeads(WhisperAhead[] aheadsArray) {
this.n_heads = new NativeLong(aheadsArray.length);
int structSize = aheadsArray[0].size();
Memory mem = new Memory(structSize * aheadsArray.length);
for (int i = 0; i < aheadsArray.length; i++) {
aheadsArray[i].write();
byte[] buffer = aheadsArray[i].getPointer().getByteArray(0, structSize);
mem.write(i * structSize, buffer, 0, buffer.length);
}
this.heads = mem;
}
@Override
protected List<String> getFieldOrder() {
return Arrays.asList("n_heads", "heads");
}
public static class ByReference extends WhisperAheads implements Structure.ByReference {}
public static class ByValue extends WhisperAheads implements Structure.ByValue {}
}

View File

@ -1,5 +1,7 @@
package io.github.ggerganov.whispercpp.params;
import com.sun.jna.*;
import java.util.Arrays;
import java.util.List;
@ -9,73 +11,21 @@ import java.util.List;
* whisper_context_default_params()
*/
public class WhisperContextParams extends Structure {
public WhisperContextParams(Pointer p) {
super(p);
}
public WhisperContextParams() {
super();
}
/** Use GPU for inference (default = true) */
/** Use GPU for inference Number (default = true) */
public CBool use_gpu;
/** Use flash attention (default = false) */
public CBool flash_attn;
/** CUDA device to use (default = 0) */
public int gpu_device;
/** [EXPERIMENTAL] Enable token-level timestamps with DTW (default = false) */
public CBool dtw_token_timestamps;
/** [EXPERIMENTAL] Alignment heads preset for DTW */
public int dtw_aheads_preset;
/** Number of top layers to use for DTW when using WHISPER_AHEADS_N_TOP_MOST preset */
public int dtw_n_top;
public WhisperAheads.ByValue dtw_aheads;
/** DTW memory size (internal use) */
public NativeLong dtw_mem_size;
/** Use GPU for inference */
/** Use GPU for inference Number (default = true) */
public void useGpu(boolean enable) {
use_gpu = enable ? CBool.TRUE : CBool.FALSE;
}
/** Use flash attention */
public void useFlashAttn(boolean enable) {
flash_attn = enable ? CBool.TRUE : CBool.FALSE;
}
/** Enable DTW token-level timestamps */
public void enableDtwTokenTimestamps(boolean enable) {
dtw_token_timestamps = enable ? CBool.TRUE : CBool.FALSE;
}
/** Set DTW alignment heads preset */
public void setDtwAheadsPreset(int preset) {
dtw_aheads_preset = preset;
}
@Override
protected List<String> getFieldOrder() {
return Arrays.asList(
"use_gpu",
"flash_attn",
"gpu_device",
"dtw_token_timestamps",
"dtw_aheads_preset",
"dtw_n_top",
"dtw_aheads",
"dtw_mem_size"
);
}
public static class ByValue extends WhisperContextParams implements Structure.ByValue {
public ByValue() { super(); }
public ByValue(Pointer p) { super(p); }
return Arrays.asList("use_gpu");
}
}

View File

@ -5,7 +5,6 @@ import io.github.ggerganov.whispercpp.callbacks.WhisperEncoderBeginCallback;
import io.github.ggerganov.whispercpp.callbacks.WhisperLogitsFilterCallback;
import io.github.ggerganov.whispercpp.callbacks.WhisperNewSegmentCallback;
import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback;
import io.github.ggerganov.whispercpp.callbacks.GgmlAbortCallback;
import java.util.Arrays;
import java.util.List;
@ -17,12 +16,10 @@ import java.util.List;
*/
public class WhisperFullParams extends Structure {
public WhisperFullParams() {
super();
}
public WhisperFullParams(Pointer p) {
super(p);
// super(p, ALIGN_MSVC);
// super(p, ALIGN_GNUC);
}
/** Sampling strategy for whisper_full() function. */
@ -72,10 +69,10 @@ public class WhisperFullParams extends Structure {
single_segment = single ? CBool.TRUE : CBool.FALSE;
}
/** Flag to print special tokens (e.g., &lt;SOT&gt;, &lt;EOT&gt;, &lt;BEG&gt;, etc.). (default = false) */
/** Flag to print special tokens (e.g., &lt;SOT>, &lt;EOT>, &lt;BEG>, etc.). (default = false) */
public CBool print_special;
/** Flag to print special tokens (e.g., &lt;SOT&gt;, &lt;EOT&gt;, &lt;BEG&gt;, etc.). (default = false) */
/** Flag to print special tokens (e.g., &lt;SOT>, &lt;EOT>, &lt;BEG>, etc.). (default = false) */
public void printSpecial(boolean enable) {
print_special = enable ? CBool.TRUE : CBool.FALSE;
}
@ -132,14 +129,6 @@ public class WhisperFullParams extends Structure {
/** Maximum tokens per segment (0, default = no limit) */
public int max_tokens;
/** [EXPERIMENTAL] Enable debug mode for extra info */
public CBool debug_mode;
/** Enable debug mode */
public void enableDebugMode(boolean enable) {
debug_mode = enable ? CBool.TRUE : CBool.FALSE;
}
/** Overwrite the audio context size (0 = use default). */
public int audio_ctx;
@ -285,16 +274,6 @@ public class WhisperFullParams extends Structure {
*/
public Pointer encoder_begin_callback_user_data;
/** Callback used to abort GGML computation */
public Pointer abort_callback;
/** User data for the abort_callback */
public Pointer abort_callback_user_data;
public void setAbortCallback(GgmlAbortCallback callback) {
abort_callback = CallbackReference.getFunctionPointer(callback);
}
/**
* Callback by each decoder to filter obtained logits.
* WhisperLogitsFilterCallback
@ -331,28 +310,17 @@ public class WhisperFullParams extends Structure {
@Override
protected List<String> getFieldOrder() {
return Arrays.asList("strategy", "n_threads", "n_max_text_ctx",
"offset_ms", "duration_ms", "translate", "no_context",
"no_timestamps", "single_segment", "print_special",
"print_progress", "print_realtime", "print_timestamps",
"token_timestamps", "thold_pt", "thold_ptsum", "max_len",
"split_on_word", "max_tokens", "debug_mode", "audio_ctx",
"tdrz_enable", "suppress_regex", "initial_prompt",
"prompt_tokens", "prompt_n_tokens", "language", "detect_language",
"suppress_blank", "suppress_nst", "temperature",
"max_initial_ts", "length_penalty", "temperature_inc",
"entropy_thold", "logprob_thold", "no_speech_thold", "greedy",
"beam_search", "new_segment_callback", "new_segment_callback_user_data",
return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate",
"no_context", "single_segment", "no_timestamps",
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "audio_ctx",
"tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
"suppress_blank", "suppress_nst", "temperature", "max_initial_ts", "length_penalty",
"temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
"new_segment_callback", "new_segment_callback_user_data",
"progress_callback", "progress_callback_user_data",
"encoder_begin_callback", "encoder_begin_callback_user_data",
"abort_callback", "abort_callback_user_data",
"logits_filter_callback", "logits_filter_callback_user_data",
"grammar_rules", "n_grammar_rules", "i_start_rule", "grammar_penalty");
}
public static class ByValue extends WhisperFullParams implements Structure.ByValue {
public ByValue() { super(); }
public ByValue(Pointer p) { super(p); }
}
}

View File

@ -76,7 +76,7 @@ class WhisperCppTest {
float[] floats = new float[b.length / 2];
//WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
WhisperFullParams.ByValue params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));
params.print_progress = CBool.FALSE;
//params.initial_prompt = "and so my fellow Americans um, like";

View File

@ -33,9 +33,6 @@ mkdir build-em && cd build-em
emcmake cmake .. && make -j
# run test
node ../tests/test-whisper.js
# For Node.js versions prior to v16.4.0, experimental features need to be enabled:
node --experimental-wasm-threads --experimental-wasm-simd ../tests/test-whisper.js
# publish npm package

View File

@ -1,6 +1,6 @@
{
"name": "whisper.cpp",
"version": "1.7.5",
"version": "1.7.4",
"description": "Whisper speech recognition",
"main": "whisper.js",
"scripts": {

View File

@ -1,6 +1,3 @@
LICENSE
pkg/
lib/whisper.*
ext/sources/*
!ext/sources/CMakeGraphVizOptions.cmake
ext/mkmf.log

View File

@ -16,18 +16,6 @@ If bundler is not being used to manage dependencies, install the gem by executin
$ gem install whispercpp
You can pass build options for whisper.cpp, for instance:
$ bundle config build.whispercpp --enable-ggml-cuda
or,
$ gem install whispercpp -- --enable-ggml-cuda
See whisper.cpp's [README](https://github.com/ggml-org/whisper.cpp/blob/master/README.md) for available options. You need convert options present the README to Ruby-style options.
For boolean options like `GGML_CUDA`, the README says `-DGGML_CUDA=1`. You need strip `-D`, prepend `--enable-` for `1` or `ON` (`--disable-` for `0` or `OFF`) and make it kebab-case: `--enable-ggml-cuda`.
For options which require arguments like `CMAKE_CUDA_ARCHITECTURES`, the README says `-DCMAKE_CUDA_ARCHITECTURES="86"`. You need strip `-D`, prepend `--`, make it kebab-case, append `=` and append argument: `--cmake-cuda-architectures="86"`.
Usage
-----
@ -240,7 +228,7 @@ The second argument `samples` may be an array, an object with `length` and `each
Development
-----------
% git clone https://github.com/ggml-org/whisper.cpp.git
% git clone https://github.com/ggerganov/whisper.cpp.git
% cd whisper.cpp/bindings/ruby
% rake test
@ -253,5 +241,5 @@ License
The same to [whisper.cpp][].
[whisper.cpp]: https://github.com/ggml-org/whisper.cpp
[models]: https://github.com/ggml-org/whisper.cpp/tree/master/models
[whisper.cpp]: https://github.com/ggerganov/whisper.cpp
[models]: https://github.com/ggerganov/whisper.cpp/tree/master/models

View File

@ -3,15 +3,11 @@ require "bundler/gem_tasks"
require "rake/testtask"
require_relative "extsources"
SOURCES_DIR = "ext/sources"
SOURCES = FileList[]
EXTSOURCES.each do |src|
basename = src.pathmap("%f")
dest = basename == "LICENSE" ? basename
: src.pathmap("%{\\.\\./\\.\\.,#{SOURCES_DIR}}p")
.pathmap("%{\\.\\./javascript,#{SOURCES_DIR}/bindings/javascript}p")
dest = basename == "LICENSE" ? basename : src.pathmap("%{../..,ext}p")
dir = dest.pathmap("%d")
file src
directory dir
@ -22,6 +18,7 @@ EXTSOURCES.each do |src|
end
CLEAN.include SOURCES
CLEAN.include FileList["ext/**/*.o", "ext/**/*.metal", "ext/**/*.tmp", "ext/whisper.{so,bundle,dll}"]
SRC = FileList["ext/*.{c,cpp,h}"]
@ -39,20 +36,6 @@ file "ext/Makefile" => SRC + ["ext/extconf.rb"] + SOURCES do |t|
ruby "extconf.rb"
end
end
if File.exist? "ext/Makefile"
task :make_clean do
cd "ext" do
sh "make", "clean"
end
end
task clean: :make_clean
task :make_distclean do
cd "ext" do
sh "make", "distclean"
end
end
task clobber: :make_distclean
end
file SO_FILE => "ext/Makefile" do |t|
chdir "ext" do

9
bindings/ruby/ext/cpu.mk Normal file
View File

@ -0,0 +1,9 @@
ggml/src/ggml-cpu/ggml-cpu-cpp.o: \
ggml/src/ggml-cpu/ggml-cpu.cpp \
ggml/include/ggml-backend.h \
ggml/include/ggml.h \
ggml/include/ggml-alloc.h \
ggml/src/ggml-backend-impl.h \
ggml/include/ggml-cpu.h \
ggml/src/ggml-impl.h
$(CXX) $(CXXFLAGS) -c $< -o $@

View File

@ -1,61 +0,0 @@
require "tsort"
class Dependencies
def initialize(cmake, options)
@cmake = cmake
@options = options
generate_dot
@libs = parse_dot
end
def to_s
@libs.join(" ")
end
private
def dot_path
File.join(__dir__, "build", "whisper.cpp.dot")
end
def generate_dot
system @cmake, "-S", "sources", "-B", "build", "--graphviz", dot_path, "-D", "BUILD_SHARED_LIBS=OFF", @options.to_s, exception: true
end
def parse_dot
static_lib_shape = nil
nodes = {}
depends = Hash.new {|h, k| h[k] = []}
class << depends
include TSort
alias tsort_each_node each_key
def tsort_each_child(node, &block)
fetch(node, []).each(&block)
end
end
File.open(dot_path).each_line do |line|
case line
when /\[\s*label\s*=\s*"Static Library"\s*,\s*shape\s*=\s*(?<shape>\w+)\s*\]/
static_lib_shape = $~[:shape]
when /\A\s*"(?<node>\w+)"\s*\[\s*label\s*=\s*"(?<label>\S+)"\s*,\s*shape\s*=\s*(?<shape>\w+)\s*\]\s*;\s*\z/
node = $~[:node]
label = $~[:label]
shape = $~[:shape]
nodes[node] = [label, shape]
when /\A\s*"(?<depender>\w+)"\s*->\s*"(?<dependee>\w+)"/
depender = $~[:depender]
dependee = $~[:dependee]
depends[depender] ||= []
depends[depender] << dependee
end
end
depends.tsort.filter_map {|node|
label, shape = nodes[node]
shape == static_lib_shape ? label : nil
}.collect {|lib| "lib#{lib}.a"}
.reverse
end
end

View File

@ -1,22 +1,208 @@
require "mkmf"
require_relative "options"
require_relative "dependencies"
require 'mkmf'
cmake = find_executable("cmake") || abort
options = Options.new
have_library("gomp") rescue nil
libs = Dependencies.new(cmake, options)
# need to use c++ compiler flags
$CXXFLAGS << ' -std=c++17'
$INCFLAGS << " -Isources/include -Isources/ggml/include -Isources/examples"
$LOCAL_LIBS << " #{libs}"
$cleanfiles << " build #{libs}"
$LDFLAGS << ' -lstdc++'
create_makefile "whisper" do |conf|
conf << <<~EOF
$(TARGET_SO): #{libs}
#{libs}: cmake-targets
cmake-targets:
#{"\t"}#{cmake} -S sources -B build -D BUILD_SHARED_LIBS=OFF -D CMAKE_ARCHIVE_OUTPUT_DIRECTORY=#{__dir__} -D CMAKE_POSITION_INDEPENDENT_CODE=ON #{options}
#{"\t"}#{cmake} --build build --config Release --target common whisper
EOF
# Set to true when building binary gems
if enable_config('static-stdlib', false)
$LDFLAGS << ' -static-libgcc -static-libstdc++'
end
if enable_config('march-tune-native', false)
$CFLAGS << ' -march=native -mtune=native'
$CXXFLAGS << ' -march=native -mtune=native'
end
if ENV['WHISPER_METAL']
$GGML_METAL ||= true
$DEPRECATE_WARNING ||= true
end
$UNAME_S = `uname -s`.chomp
$UNAME_P = `uname -p`.chomp
$UNAME_M = `uname -m`.chomp
if $UNAME_S == 'Darwin'
unless ENV['GGML_NO_METAL']
$GGML_METAL ||= true
end
$GGML_NO_OPENMP ||= true
end
if $GGML_METAL
$GGML_METAL_EMBED_LIBRARY = true
end
$MK_CPPFLAGS = '-Iggml/include -Iggml/src -Iggml/src/ggml-cpu -Iinclude -Isrc -Iexamples -DGGML_USE_CPU'
$MK_CFLAGS = '-std=c11 -fPIC'
$MK_CXXFLAGS = '-std=c++17 -fPIC'
$MK_NVCCFLAGS = '-std=c++17'
$MK_LDFLAGS = ''
$OBJ_GGML = []
$OBJ_WHISPER = []
$OBJ_COMMON = []
$OBJ_SDL = []
$MK_CPPFLAGS << ' -D_XOPEN_SOURCE=600'
if $UNAME_S == 'Linux'
$MK_CPPFLAGS << ' -D_GNU_SOURCE'
end
if $UNAME_S == 'Darwin'
$MK_CPPFLAGS << ' -D_DARWIN_C_SOURCE'
end
if ENV['WHISPER_DEBUG']
$MK_CFLAGS << ' -O0 -g'
$MK_CXXFLAGS << ' -O0 -g'
$MK_LDFLAGS << ' -g'
$MK_NVCCFLAGS << ' -O0 -g'
else
$MK_CPPFLAGS << ' -DNDEBUG'
$MK_CFLAGS << ' -O3'
$MK_CXXFLAGS << ' -O3'
$MK_NVCCFLAGS << ' -O3'
end
$WARN_FLAGS =
' -Wall' <<
' -Wextra' <<
' -Wpedantic' <<
' -Wcast-qual' <<
' -Wno-unused-function'
$MK_CFLAGS <<
$WARN_FLAGS <<
' -Wshadow' <<
' -Wstrict-prototypes' <<
' -Wpointer-arith' <<
' -Wmissing-prototypes' <<
' -Werror=implicit-int' <<
' -Werror=implicit-function-declaration'
$MK_CXXFLAGS <<
$WARN_FLAGS <<
' -Wmissing-declarations' <<
' -Wmissing-noreturn'
unless `#{cc_command} #{$LDFLAGS} -Wl,-v 2>&1`.chomp.include? 'dyld-1015.7'
$MK_CPPFLAGS << ' -DHAVE_BUGGY_APPLE_LINKER'
end
if %w[Linux Darwin FreeBSD NetBSD OpenBSD Haiku].include? $UNAME_S
$MK_CFLAGS << ' -pthread'
$MK_CXXFLAGS << ' -pthread'
end
unless $_WIN32
$DSO_EXT = '.so'
else
$DSO_EXT = '.dll'
end
unless ENV['RISCV']
if %w[x86_64 i686 amd64].include? $UNAME_M
$HOST_CXXFLAGS ||= ''
$MK_CFLAGS << ' -march=native -mtune=native'
$HOST_CXXFLAGS << ' -march=native -mtune=native'
end
else
$MK_CFLAGS << ' -march=rv64gcv -mabi=lp64d'
$MK_CXXFLAGS << ' -march=rv64gcv -mabi=lp64d'
end
unless ENV['GGML_NO_ACCELERATE']
if $UNAME_S == 'Darwin'
$MK_CPPFLAGS << ' -DGGML_USE_ACCELERATE -DGGML_USE_BLAS -DGGML_BLAS_USE_ACCELERATE'
$MK_CPPFLAGS << ' -DACCELERATE_NEW_LAPACK'
$MK_CPPFLAGS << ' -DACCELERATE_LAPACK_ILP64'
$MK_LDFLAGS << ' -framework Accelerate'
$OBJ_GGML << 'ggml/src/ggml-blas/ggml-blas.o'
end
end
if ENV['GGML_OPENBLAS']
$MK_CPPFLAGS << " -DGGML_USE_BLAS #{`pkg-config --cflags-only-I openblas`.chomp}"
$MK_CFLAGS << " #{`pkg-config --cflags-only-other openblas)`.chomp}"
$MK_LDFLAGS << " #{`pkg-config --libs openblas`}"
$OBJ_GGML << 'ggml/src/ggml-blas/ggml-blas.o'
end
if ENV['GGML_OPENBLAS64']
$MK_CPPFLAGS << " -DGGML_USE_BLAS #{`pkg-config --cflags-only-I openblas64`.chomp}"
$MK_CFLAGS << " #{`pkg-config --cflags-only-other openblas64)`.chomp}"
$MK_LDFLAGS << " #{`pkg-config --libs openblas64`}"
$OBJ_GGML << 'ggml/src/ggml-blas/ggml-blas.o'
end
if $GGML_METAL
$MK_CPPFLAGS << ' -DGGML_USE_METAL'
$MK_LDFLAGS << ' -framework Foundation -framework Metal -framework MetalKit'
$OBJ_GGML << 'ggml/src/ggml-metal/ggml-metal.o'
if ENV['GGML_METAL_NDEBUG']
$MK_CPPFLAGS << ' -DGGML_METAL_NDEBUG'
end
if $GGML_METAL_EMBED_LIBRARY
$MK_CPPFLAGS << ' -DGGML_METAL_EMBED_LIBRARY'
$OBJ_GGML << 'ggml/src/ggml-metal/ggml-metal-embed.o'
end
end
$OBJ_GGML <<
'ggml/src/ggml.o' <<
'ggml/src/ggml-alloc.o' <<
'ggml/src/ggml-backend.o' <<
'ggml/src/ggml-backend-reg.o' <<
'ggml/src/ggml-opt.o' <<
'ggml/src/ggml-quants.o' <<
'ggml/src/ggml-threading.o' <<
'ggml/src/ggml-cpu/ggml-cpu.o' <<
'ggml/src/ggml-cpu/ggml-cpu-cpp.o' <<
'ggml/src/ggml-cpu/ggml-cpu-aarch64.o' <<
'ggml/src/ggml-cpu/ggml-cpu-hbm.o' <<
'ggml/src/ggml-cpu/ggml-cpu-quants.o' <<
'ggml/src/ggml-cpu/ggml-cpu-traits.o'
$OBJ_WHISPER <<
'src/whisper.o' <<
'examples/common.o' <<
'examples/common-whisper.o'
$objs = $OBJ_GGML + $OBJ_WHISPER + $OBJ_COMMON + $OBJ_SDL
$objs <<
"ruby_whisper.o" <<
"ruby_whisper_context.o" <<
"ruby_whisper_transcribe.o" <<
"ruby_whisper_params.o" <<
"ruby_whisper_error.o" <<
"ruby_whisper_segment.o" <<
"ruby_whisper_model.o"
$CPPFLAGS = "#{$MK_CPPFLAGS} #{$CPPFLAGS}"
$CFLAGS = "#{$CPPFLAGS} #{$MK_CFLAGS} #{$GF_CFLAGS} #{$CFLAGS}"
$BASE_CXXFLAGS = "#{$MK_CXXFLAGS} #{$CXXFLAGS}"
$CXXFLAGS = "#{$BASE_CXXFLAGS} #{$HOST_CXXFLAGS} #{$GF_CXXFLAGS} #{$CPPFLAGS}"
$NVCCFLAGS = "#{$MK_NVCCFLAGS} #{$NVCCFLAGS}"
$LDFLAGS = "#{$MK_LDFLAGS} #{$LDFLAGS}"
create_makefile('whisper')
File.open 'Makefile', 'a' do |file|
file.puts 'include scripts/get-flags.mk'
file.puts 'include cpu.mk'
if $GGML_METAL
file.puts 'include metal.mk'
if $GGML_METAL_EMBED_LIBRARY
file.puts 'include metal-embed.mk'
end
end
end

View File

@ -0,0 +1,17 @@
ggml/src/ggml-metal/ggml-metal-embed.o: \
ggml/src/ggml-metal/ggml-metal.metal \
ggml/src/ggml-metal/ggml-metal-impl.h \
ggml/src/ggml-common.h
@echo "Embedding Metal library"
@sed -e '/__embed_ggml-common.h__/r ggml/src/ggml-common.h' -e '/__embed_ggml-common.h__/d' < ggml/src/ggml-metal/ggml-metal.metal > ggml/src/ggml-metal/ggml-metal-embed.metal.tmp
@sed -e '/#include "ggml-metal-impl.h"/r ggml/src/ggml-metal/ggml-metal-impl.h' -e '/#include "ggml-metal-impl.h"/d' < ggml/src/ggml-metal/ggml-metal-embed.metal.tmp > ggml/src/ggml-metal/ggml-metal-embed.metal
$(eval TEMP_ASSEMBLY=$(shell mktemp -d))
@echo ".section __DATA, __ggml_metallib" > $(TEMP_ASSEMBLY)/ggml-metal-embed.s
@echo ".globl _ggml_metallib_start" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
@echo "_ggml_metallib_start:" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
@echo ".incbin \"ggml/src/ggml-metal/ggml-metal-embed.metal\"" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
@echo ".globl _ggml_metallib_end" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
@echo "_ggml_metallib_end:" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
$(CC) $(CFLAGS) -c $(TEMP_ASSEMBLY)/ggml-metal-embed.s -o $@
@rm -f ${TEMP_ASSEMBLY}/ggml-metal-embed.s
@rmdir ${TEMP_ASSEMBLY}

View File

@ -0,0 +1,6 @@
ggml/src/ggml-metal/ggml-metal.o: \
ggml/src/ggml-metal/ggml-metal.m \
ggml/src/ggml-metal/ggml-metal-impl.h \
ggml/include/ggml-metal.h \
ggml/include/ggml.h
$(CC) $(CFLAGS) -c $< -o $@

View File

@ -1,219 +0,0 @@
class Options
def initialize
@options = {}
@pending_options = []
@ignored_options = []
configure
end
def help
@options
.collect_concat {|name, (type, value)|
option = option_name(name)
if type == :bool
["--enable-#{option}", "--disable-#{option}"]
else
"--#{option}=#{type.upcase}"
end
}
.join($/)
end
def to_s
@options
.reject {|name, (type, value)| value.nil?}
.collect {|name, (type, value)| "-D #{name}=#{value == true ? "ON" : value == false ? "OFF" : value.shellescape}"}
.join(" ")
end
def cmake_options
return @cmake_options if @cmake_options
output = nil
Dir.chdir __dir__ do
output = `cmake -S sources -B build -L`
end
started = false
@cmake_options = output.lines.filter_map {|line|
if line.chomp == "-- Cache values"
started = true
next
end
next unless started
option, value = line.chomp.split("=", 2)
name, type = option.split(":", 2)
[name, type, value]
}
end
def missing_options
cmake_options.collect {|name, type, value| name} -
@options.keys - @pending_options - @ignored_options
end
def extra_options
@options.keys + @pending_options - @ignored_options -
cmake_options.collect {|name, type, value| name}
end
private
def configure
filepath "ACCELERATE_FRAMEWORK"
ignored "BUILD_SHARED_LIBS"
ignored "BUILD_TESTING"
ignored "CMAKE_BUILD_TYPE"
ignored "CMAKE_INSTALL_PREFIX"
string "CMAKE_OSX_ARCHITECTURES"
ignored "CMAKE_OSX_DEPLOYMENT_TARGET"
string "CMAKE_OSX_SYSROOT"
filepath "FOUNDATION_LIBRARY"
bool "GGML_ACCELERATE"
bool "GGML_ALL_WARNINGS_3RD_PARTY"
bool "GGML_AMX_BF16"
bool "GGML_AMX_INT8"
bool "GGML_AMX_TILE"
bool "GGML_AVX"
bool "GGML_AVX2"
bool "GGML_AVX512"
bool "GGML_AVX512_BF16"
bool "GGML_AVX512_VBMI"
bool "GGML_AVX512_VNNI"
bool "GGML_AVX_VNNI"
ignored "GGML_BACKEND_DL"
ignored "GGML_BIN_INSTALL_DIR"
bool "GGML_BLAS"
string "GGML_BLAS_VENDOR"
bool "GGML_BMI2"
ignored "GGML_BUILD_EXAMPLES"
ignored "GGML_BUILD_TESTS"
filepath "GGML_CCACHE_FOUND"
bool "GGML_CPU"
bool "GGML_CPU_AARCH64"
ignored "GGML_CPU_ALL_VARIANTS"
string "GGML_CPU_ARM_ARCH"
bool "GGML_CPU_HBM"
bool "GGML_CPU_KLEIDIAI"
string "GGML_CPU_POWERPC_CPUTYPE"
bool "GGML_CUDA"
string "GGML_CUDA_COMPRESSION_MODE"
bool "GGML_CUDA_F16"
bool "GGML_CUDA_FA"
bool "GGML_CUDA_FA_ALL_QUANTS"
bool "GGML_CUDA_FORCE_CUBLAS"
bool "GGML_CUDA_FORCE_MMQ"
ignored "GGML_CUDA_GRAPHS"
bool "GGML_CUDA_NO_PEER_COPY"
bool "GGML_CUDA_NO_VMM"
string "GGML_CUDA_PEER_MAX_BATCH_SIZE"
bool "GGML_F16C"
bool "GGML_FMA"
bool "GGML_GPROF"
bool "GGML_HIP"
bool "GGML_HIP_GRAPHS"
bool "GGML_HIP_NO_VMM"
bool "GGML_HIP_ROCWMMA_FATTN"
ignored "GGML_INCLUDE_INSTALL_DIR"
bool "GGML_KOMPUTE"
bool "GGML_LASX"
ignored "GGML_LIB_INSTALL_DIR"
ignored "GGML_LLAMAFILE"
bool "GGML_LSX"
bool "GGML_LTO"
bool "GGML_METAL"
bool "GGML_METAL_EMBED_LIBRARY"
string "GGML_METAL_MACOSX_VERSION_MIN"
bool "GGML_METAL_NDEBUG"
bool "GGML_METAL_SHADER_DEBUG"
string "GGML_METAL_STD"
bool "GGML_METAL_USE_BF16"
bool "GGML_MUSA"
bool "GGML_NATIVE"
bool "GGML_OPENCL"
bool "GGML_OPENCL_EMBED_KERNELS"
bool "GGML_OPENCL_PROFILING"
string "GGML_OPENCL_TARGET_VERSION"
bool "GGML_OPENCL_USE_ADRENO_KERNELS"
bool "GGML_OPENMP"
bool "GGML_RPC"
bool "GGML_RVV"
bool "GGML_RV_ZFH"
pending "GGML_SCCACHE_FOUND"
string "GGML_SCHED_MAX_COPIES"
bool "GGML_SSE42"
ignored "GGML_STATIC"
bool "GGML_SYCL"
string "GGML_SYCL_DEVICE_ARCH"
bool "GGML_SYCL_F16"
bool "GGML_SYCL_GRAPH"
string "GGML_SYCL_TARGET"
bool "GGML_VULKAN"
bool "GGML_VULKAN_CHECK_RESULTS"
bool "GGML_VULKAN_DEBUG"
bool "GGML_VULKAN_MEMORY_DEBUG"
bool "GGML_VULKAN_PERF"
ignored "GGML_VULKAN_RUN_TESTS"
filepath "GGML_VULKAN_SHADERS_GEN_TOOLCHAIN"
bool "GGML_VULKAN_SHADER_DEBUG_INFO"
pending "GGML_VULKAN_VALIDATE"
bool "GGML_VXE"
filepath "GIT_EXE"
filepath "MATH_LIBRARY"
filepath "METALKIT_FRAMEWORK"
filepath "METAL_FRAMEWORK"
bool "WHISPER_ALL_WARNINGS"
bool "WHISPER_ALL_WARNINGS_3RD_PARTY"
ignored "WHISPER_BIN_INSTALL_DIR"
ignored "WHISPER_BUILD_EXAMPLES"
ignored "WHISPER_BUILD_SERVER"
ignored"WHISPER_BUILD_TESTS"
bool "WHISPER_CCACHE"
bool "WHISPER_COREML"
bool "WHISPER_COREML_ALLOW_FALLBACK"
ignored "WHISPER_CURL"
bool "WHISPER_FATAL_WARNINGS"
ignored "WHISPER_FFMPEG"
ignored "WHISPER_INCLUDE_INSTALL_DIR"
ignored "WHISPER_LIB_INSTALL_DIR"
bool "WHISPER_OPENVINO"
bool "WHISPER_SANITIZE_ADDRESS"
bool "WHISPER_SANITIZE_THREAD"
bool "WHISPER_SANITIZE_UNDEFINED"
ignored "WHISPER_SDL2"
pending "WHISPER_USE_SYSTEM_GGML"
end
def option_name(name)
name.downcase.gsub("_", "-")
end
def bool(name)
option = option_name(name)
value = enable_config(option)
@options[name] = [:bool, value]
end
def string(name, type=:string)
option = "--#{option_name(name)}"
value = arg_config(option)
raise "String expected for #{option}" if value == true || value&.empty?
@options[name] = [type, value]
end
def path(name)
string(name, :path)
end
def filepath(name)
string(name, :filepath)
end
def pending(name)
@pending_options << name
end
def ignored(name)
@ignored_options << name
end
end

View File

@ -19,7 +19,6 @@ typedef struct {
bool diarize;
ruby_whisper_callback_container *new_segment_callback_container;
ruby_whisper_callback_container *progress_callback_container;
ruby_whisper_callback_container *encoder_begin_callback_container;
ruby_whisper_callback_container *abort_callback_container;
} ruby_whisper_params;

View File

@ -26,7 +26,7 @@
rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \
rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1);
#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 32
#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 30
extern VALUE cParams;
@ -63,8 +63,6 @@ static ID id_new_segment_callback;
static ID id_new_segment_callback_user_data;
static ID id_progress_callback;
static ID id_progress_callback_user_data;
static ID id_encoder_begin_callback;
static ID id_encoder_begin_callback_user_data;
static ID id_abort_callback;
static ID id_abort_callback_user_data;
@ -128,33 +126,6 @@ static void progress_callback(struct whisper_context *ctx, struct whisper_state
}
}
static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_state *state, void *user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
bool is_aborted = false;
VALUE result;
// Currently, doesn't support state because
// those require to resolve GC-related problems.
if (!NIL_P(container->callback)) {
result = rb_funcall(container->callback, id_call, 3, *container->context, Qnil, container->user_data);
if (result == Qfalse) {
is_aborted = true;
}
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
return !is_aborted;
}
for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
result = rb_funcall(cb, id_call, 0);
if (result == Qfalse) {
is_aborted = true;
}
}
return !is_aborted;
}
static bool abort_callback(void * user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
if (!NIL_P(container->callback)) {
@ -190,12 +161,6 @@ void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
rwp->params.progress_callback_user_data = rwp->progress_callback_container;
}
if (!NIL_P(rwp->encoder_begin_callback_container->callback) || 0 != RARRAY_LEN(rwp->encoder_begin_callback_container->callbacks)) {
rwp->encoder_begin_callback_container->context = context;
rwp->params.encoder_begin_callback = encoder_begin_callback;
rwp->params.encoder_begin_callback_user_data = rwp->encoder_begin_callback_container;
}
if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
rwp->abort_callback_container->context = context;
rwp->params.abort_callback = abort_callback;
@ -208,7 +173,6 @@ rb_whisper_params_mark(ruby_whisper_params *rwp)
{
rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
rb_whisper_callbcack_container_mark(rwp->encoder_begin_callback_container);
rb_whisper_callbcack_container_mark(rwp->abort_callback_container);
}
@ -234,7 +198,6 @@ ruby_whisper_params_allocate(VALUE klass)
rwp->diarize = false;
rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
rwp->progress_callback_container = rb_whisper_callback_container_allocate();
rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate();
rwp->abort_callback_container = rb_whisper_callback_container_allocate();
return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
}
@ -886,57 +849,6 @@ ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value)
rwp->progress_callback_container->user_data = value;
return value;
}
static VALUE
ruby_whisper_params_get_encoder_begin_callback(VALUE self)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->encoder_begin_callback_container->callback;
}
/*
* Sets encoder begin callback, called when the encoder starts.
*
* params.encoder_begin_callback = ->(context, _, user_data) {
* # ...
* }
*
* call-seq:
* encoder_begin_callback = callback -> callback
*/
static VALUE
ruby_whisper_params_set_encoder_begin_callback(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->encoder_begin_callback_container->callback = value;
return value;
}
static VALUE
ruby_whisper_params_get_encoder_begin_callback_user_data(VALUE self)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->encoder_begin_callback_container->user_data;
}
/*
* Sets user data passed to the last argument of encoder begin callback.
*
* call-seq:
* encoder_begin_callback_user_data = user_data -> use_data
*/
static VALUE
ruby_whisper_params_set_encoder_begin_callback_user_data(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->encoder_begin_callback_container->user_data = value;
return value;
}
static VALUE
ruby_whisper_params_get_abort_callback(VALUE self)
{
@ -1006,7 +918,7 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
return self;
}
rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT, values);
rb_get_kwargs(kw_hash, &param_names, 0, RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT, &values);
Data_Get_Struct(self, ruby_whisper_params, rwp);
for (i = 0; i < RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT; i++) {
@ -1046,8 +958,6 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
SET_PARAM_IF_SAME(new_segment_callback_user_data)
SET_PARAM_IF_SAME(progress_callback)
SET_PARAM_IF_SAME(progress_callback_user_data)
SET_PARAM_IF_SAME(encoder_begin_callback)
SET_PARAM_IF_SAME(encoder_begin_callback_user_data)
SET_PARAM_IF_SAME(abort_callback)
SET_PARAM_IF_SAME(abort_callback_user_data)
}
@ -1098,26 +1008,6 @@ ruby_whisper_params_on_progress(VALUE self)
return Qnil;
}
/*
* Hook called when the encoder starts.
*
* whisper.on_encoder_begin do
* # ...
* end
*
* call-seq:
* on_encoder_begin { ... }
*/
static VALUE
ruby_whisper_params_on_encoder_begin(VALUE self)
{
ruby_whisper_params *rws;
Data_Get_Struct(self, ruby_whisper_params, rws);
const VALUE blk = rb_block_proc();
rb_ary_push(rws->encoder_begin_callback_container->callbacks, blk);
return Qnil;
}
/*
* Call block to determine whether abort or not. Return +true+ when you want to abort.
*
@ -1178,13 +1068,10 @@ init_ruby_whisper_params(VALUE *mWhisper)
DEFINE_PARAM(new_segment_callback_user_data, 25)
DEFINE_PARAM(progress_callback, 26)
DEFINE_PARAM(progress_callback_user_data, 27)
DEFINE_PARAM(encoder_begin_callback, 28)
DEFINE_PARAM(encoder_begin_callback_user_data, 29)
DEFINE_PARAM(abort_callback, 30)
DEFINE_PARAM(abort_callback_user_data, 31)
DEFINE_PARAM(abort_callback, 28)
DEFINE_PARAM(abort_callback_user_data, 29)
rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
rb_define_method(cParams, "on_encoder_begin", ruby_whisper_params_on_encoder_begin, 0);
rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0);
}

View File

@ -50,16 +50,15 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return self;
}
// Commented out because it is work in progress
// {
// static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
// rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
// bool is_aborted = *(bool*)user_data;
// return !is_aborted;
// };
// rwp->params.encoder_begin_callback_user_data = &is_aborted;
// }
rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
rwp->params.encoder_begin_callback_user_data = &is_aborted;
}
register_callbacks(rwp, &self);

View File

@ -1,8 +0,0 @@
set(GRAPHVIZ_EXECUTABLES FALSE)
set(GRAPHVIZ_STATIC_LIBS TRUE)
set(GRAPHVIZ_SHARED_LIBS FALSE)
set(GRAPHVIZ_MODULE_LIBS FALSE)
set(GRAPHVIZ_INTERFACE_LIBS FALSE)
set(GRAPHVIZ_OBJECT_LIBS FALSE)
set(GRAPHVIZ_UNKNOWN_LIBS FALSE)
set(GRAPHVIZ_GENERATE_DEPENDERS FALSE)

View File

@ -1,34 +1,6 @@
ignored_dirs = %w[
.devops
examples/wchess/wchess.wasm
examples/whisper.android
examples/whisper.android.java
examples/whisper.objc
examples/whisper.swiftui
grammars
models
samples
scripts
]
ignored_files = %w[
AUTHORS
Makefile
README.md
README_sycl.md
.gitignore
.gitmodules
whisper.nvim
twitch.sh
yt-wsp.sh
]
require "yaml"
EXTSOURCES =
`git ls-files -z ../..`.split("\x0")
.select {|file|
basename = File.basename(file)
ignored_dirs.all? {|dir| !file.start_with?("../../#{dir}")} &&
!ignored_files.include?(basename) &&
(file.start_with?("../..") || file.start_with?("../javascript")) &&
(!file.start_with?("../../.github/") || basename == "bindings-ruby.yml")
}
sources = `git ls-files -z ../..`.split("\x0")
paths = YAML.load_file("../../.github/workflows/bindings-ruby.yml")[true]["push"]["paths"]
paths.delete "bindings/ruby/**"
EXTSOURCES = (Dir.glob(paths, base: "../..").collect {|path| "../../#{path}"} << "../../LICENSE") & sources

View File

@ -34,7 +34,7 @@ module Whisper
when /darwin/
Pathname(Dir.home)/"Library/Caches"
else
ENV.key?("XDG_CACHE_HOME") ? Pathname(ENV["XDG_CACHE_HOME"]) : Pathname(Dir.home)/".cache"
ENV.key?("XDG_CACHE_HOME") ? ENV["XDG_CACHE_HOME"] : Pathname(Dir.home)/".cache"
end
base/"whisper.cpp"
end
@ -53,10 +53,8 @@ module Whisper
http.request request do |response|
case response
when Net::HTTPNotModified
# noop
# noop
when Net::HTTPOK
return if !response.key?("last-modified") && cache_path.exist?
download response
when Net::HTTPRedirection
request URI(response["location"]), headers
@ -70,7 +68,7 @@ module Whisper
rescue => err
if cache_path.exist?
warn err
# Use cache file
# Use cache file
else
raise
end

View File

@ -7,7 +7,6 @@ module Whisper
type log_callback = ^(Integer level, String message, Object user_data) -> void
type new_segment_callback = ^(Whisper::Context, void, Integer n_new, Object user_data) -> void
type progress_callback = ^(Whisper::Context, void, Integer progress, Object user_data) -> void
type encoder_begin_callback = ^(Whisper::Context, void, Object user_data) -> void
type abort_callback = ^(Whisper::Context, void, Object user_data) -> boolish
LOG_LEVEL_NONE: Integer
@ -24,20 +23,9 @@ module Whisper
def self.log_set: (log_callback, Object? user_data) -> log_callback
class Context
def self.new: (path | ::URI::HTTP) -> instance
# transcribe a single file
# can emit to a block results
#
# params = Whisper::Params.new
# params.duration = 60_000
# whisper.transcribe "path/to/audio.wav", params do |text|
# puts text
# end
#
def self.new: (string | _ToPath | ::URI::HTTP) -> instance
def transcribe: (string, Params) -> self
| (string, Params) { (String) -> void } -> self
def model_n_vocab: () -> Integer
def model_n_audio_ctx: () -> Integer
def model_n_audio_state: () -> Integer
@ -46,72 +34,19 @@ module Whisper
def model_n_mels: () -> Integer
def model_ftype: () -> Integer
def model_type: () -> String
# Yields each Whisper::Segment:
#
# whisper.transcribe("path/to/audio.wav", params)
# whisper.each_segment do |segment|
# puts segment.text
# end
#
# Returns an Enumerator if no block given:
#
# whisper.transcribe("path/to/audio.wav", params)
# enum = whisper.each_segment
# enum.to_a # => [#<Whisper::Segment>, ...]
#
def each_segment: { (Segment) -> void } -> void
| () -> Enumerator[Segment]
def model: () -> Model
def full_get_segment: (Integer nth) -> Segment
def full_n_segments: () -> Integer
# Language ID, which can be converted to string by Whisper.lang_str and Whisper.lang_str_full.
#
def full_lang_id: () -> Integer
# Start time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
#
# full_get_segment_t0(3) # => 1668 (16680 ms)
#
def full_get_segment_t0: (Integer) -> Integer
# End time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
#
# full_get_segment_t1(3) # => 1668 (16680 ms)
#
def full_get_segment_t1: (Integer) -> Integer
# Whether the next segment indexed by +segment_index+ is predicated as a speaker turn.
#
# full_get_segment_speacker_turn_next(3) # => true
#
def full_get_segment_speaker_turn_next: (Integer) -> (true | false)
# Text of a segment indexed by +segment_index+.
#
# full_get_segment_text(3) # => "ask not what your country can do for you, ..."
#
def full_get_segment_text: (Integer) -> String
def full_get_segment_no_speech_prob: (Integer) -> Float
# Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
# Not thread safe for same context
# Uses the specified decoding strategy to obtain the text.
#
# The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data.
#
def full: (Params, Array[Float] samples, ?Integer n_samples) -> self
| (Params, _Samples, ?Integer n_samples) -> self
# Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
# Result is stored in the default state of the context
# Not thread safe if executed in parallel on the same context.
# It seems this approach can offer some speedup in some cases.
# However, the transcription accuracy can be worse at the beginning and end of each chunk.
#
def full_parallel: (Params, Array[Float], ?Integer n_samples) -> self
| (Params, _Samples, ?Integer n_samples) -> self
| (Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self
@ -147,223 +82,71 @@ module Whisper
?new_segment_callback_user_data: Object,
?progress_callback: progress_callback,
?progress_callback_user_data: Object,
?encoder_begin_callback: encoder_begin_callback,
?encoder_begin_callback_user_data: Object,
?abort_callback: abort_callback,
?abort_callback_user_data: Object
) -> instance
# params.language = "auto" | "en", etc...
#
def language=: (String) -> String # TODO: Enumerate lang names
def language: () -> String
def translate=: (boolish) -> boolish
def translate: () -> (true | false)
def no_context=: (boolish) -> boolish
# If true, does not use past transcription (if any) as initial prompt for the decoder.
#
def no_context: () -> (true | false)
def single_segment=: (boolish) -> boolish
# If true, forces single segment output (useful for streaming).
#
def single_segment: () -> (true | false)
def print_special=: (boolish) -> boolish
# If true, prints special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.).
#
def print_special: () -> (true | false)
def print_progress=: (boolish) -> boolish
# If true, prints progress information.
#
def print_progress: () -> (true | false)
def print_realtime=: (boolish) -> boolish
# If true, prints results from within whisper.cpp. (avoid it, use callback instead)
#
def print_realtime: () -> (true | false)
# If true, prints timestamps for each text segment when printing realtime.
#
def print_timestamps=: (boolish) -> boolish
def print_timestamps: () -> (true | false)
def suppress_blank=: (boolish) -> boolish
# If true, suppresses blank outputs.
#
def suppress_blank: () -> (true | false)
def suppress_nst=: (boolish) -> boolish
# If true, suppresses non-speech-tokens.
#
def suppress_nst: () -> (true | false)
def token_timestamps=: (boolish) -> boolish
# If true, enables token-level timestamps.
#
def token_timestamps: () -> (true | false)
def split_on_word=: (boolish) -> boolish
# If true, split on word rather than on token (when used with max_len).
#
def split_on_word: () -> (true | false)
def initial_prompt=: (_ToS) -> _ToS
# Tokens to provide to the whisper decoder as initial prompt
# these are prepended to any existing text context from a previous call
# use whisper_tokenize() to convert text to tokens.
# Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224).
#
def initial_prompt: () -> (String | nil)
def diarize=: (boolish) -> boolish
# If true, enables diarization.
#
def diarize: () -> (true | false)
def offset=: (Integer) -> Integer
# Start offset in ms.
#
def offset: () -> Integer
def duration=: (Integer) -> Integer
# Audio duration to process in ms.
#
def duration: () -> Integer
def max_text_tokens=: (Integer) -> Integer
# Max tokens to use from past text as prompt for the decoder.
#
def max_text_tokens: () -> Integer
def temperature=: (Float) -> Float
def temperature: () -> Float
def max_initial_ts=: (Float) -> Float
# See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
#
def max_initial_ts: () -> Float
def length_penalty=: (Float) -> Float
def length_penalty: () -> Float
def temperature_inc=: (Float) -> Float
def temperature_inc: () -> Float
def entropy_thold=: (Float) -> Float
# Similar to OpenAI's "compression_ratio_threshold"
#
def entropy_thold: () -> Float
def logprob_thold=: (Float) -> Float
def logprob_thold: () -> Float
def no_speech_thold=: (Float) -> Float
def no_speech_thold: () -> Float
# Sets new segment callback, called for every newly generated text segment.
#
# params.new_segment_callback = ->(context, _, n_new, user_data) {
# # ...
# }
#
def new_segment_callback=: (new_segment_callback) -> new_segment_callback
def new_segment_callback: () -> (new_segment_callback | nil)
# Sets user data passed to the last argument of new segment callback.
#
def new_segment_callback_user_data=: (Object) -> Object
def new_segment_callback_user_data: () -> Object
# Sets progress callback, called on each progress update.
#
# params.new_segment_callback = ->(context, _, progress, user_data) {
# # ...
# }
#
# +progress+ is an Integer between 0 and 100.
#
def progress_callback=: (progress_callback) -> progress_callback
def progress_callback: () -> (progress_callback | nil)
# Sets user data passed to the last argument of progress callback.
#
def progress_callback_user_data=: (Object) -> Object
def progress_callback_user_data: () -> Object
# Sets encoder begin callback, called when the encoder starts.
#
def encoder_begin_callback=: (encoder_begin_callback) -> encoder_begin_callback
def encoder_begin_callback: () -> (encoder_begin_callback | nil)
# Sets user data passed to the last argument of encoder begin callback.
#
def encoder_begin_callback_user_data=: (Object) -> Object
def encoder_begin_callback_user_data: () -> Object
# Sets abort callback, called to check if the process should be aborted.
#
# params.abort_callback = ->(user_data) {
# # ...
# }
#
#
def abort_callback=: (abort_callback) -> abort_callback
def abort_callback: () -> (abort_callback | nil)
# Sets user data passed to the last argument of abort callback.
#
def abort_callback_user_data=: (Object) -> Object
def abort_callback_user_data: () -> Object
# Hook called on new segment. Yields each Whisper::Segment.
#
# whisper.on_new_segment do |segment|
# # ...
# end
#
def on_new_segment: { (Segment) -> void } -> void
# Hook called on progress update. Yields each progress Integer between 0 and 100.
#
def on_progress: { (Integer progress) -> void } -> void
# Hook called on encoder starts.
#
def on_encoder_begin: { () -> void } -> void
# Call block to determine whether abort or not. Return +true+ when you want to abort.
#
# params.abort_on do
# if some_condition
# true # abort
# else
# false # continue
# end
# end
#
def abort_on: { (Object user_data) -> boolish } -> void
end
@ -384,24 +167,16 @@ module Whisper
def type: () -> String
class URI
def self.new: (string | ::URI::HTTP) -> instance
def self.new: (string | ::URI::HTTP) -> self
def to_path: -> String
def clear_cache: -> void
end
end
class Segment
# Start time in milliseconds.
#
def start_time: () -> Integer
# End time in milliseconds.
#
def end_time: () -> Integer
# Whether the next segment is predicted as a speaker turn.
def speaker_next_turn?: () -> (true | false)
def text: () -> String
def no_speech_prob: () -> Float
end

View File

@ -6,9 +6,9 @@ class TestBase < Test::Unit::TestCase
AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")
class << self
def whisper
return @whisper if @whisper
attr_reader :whisper
def startup
@whisper = Whisper::Context.new("base.en")
params = Whisper::Params.new
params.print_timestamps = false
@ -21,15 +21,4 @@ class TestBase < Test::Unit::TestCase
def whisper
self.class.whisper
end
module BuildOptions
load "ext/options.rb", self
Options.include self
def enable_config(name)
end
def arg_config(name)
end
end
end

View File

@ -25,7 +25,7 @@ class TestCallback < TestBase
assert start_time >= 0
assert_kind_of Integer, end_time
assert end_time > 0
assert_match(/ask not what your country can do for you, ask what you can do for your country/, text) if i_segment == 0
assert_match /ask not what your country can do for you, ask what you can do for your country/, text if i_segment == 0
end
}
@ -111,48 +111,6 @@ class TestCallback < TestBase
assert_equal 100, last
end
def test_encoder_begin_callback
i = 0
@params.encoder_begin_callback = ->(context, state, user_data) {
i += 1
}
@whisper.transcribe(@audio, @params)
assert i > 0
end
def test_encoder_begin_callback_abort
logs = []
Whisper.log_set -> (level, buffer, user_data) {
logs << buffer if level == Whisper::LOG_LEVEL_ERROR
}, logs
@params.encoder_begin_callback = ->(context, state, user_data) {
return false
}
@whisper.transcribe(@audio, @params)
assert_match(/encoder_begin_callback returned false - aborting/, logs.join)
Whisper.log_set ->(level, buffer, user_data) {}, nil
end
def test_encoder_begin_callback_user_data
udata = Object.new
@params.encoder_begin_callback_user_data = udata
yielded = nil
@params.encoder_begin_callback = ->(context, state, user_data) {
yielded = user_data
}
@whisper.transcribe(@audio, @params)
assert_same udata, yielded
end
def test_on_encoder_begin
i = 0
@params.on_encoder_begin do
i += 1
end
@whisper.transcribe(@audio, @params)
assert i > 0
end
def test_abort_callback
i = 0
@params.abort_callback = ->(user_data) {
@ -187,9 +145,9 @@ class TestCallback < TestBase
def test_abort_on
do_abort = false
_aborted_from_callback = false
aborted_from_callback = false
@params.on_new_segment do |segment|
do_abort = true if segment.text.match?(/ask/)
do_abort = true if segment.text.match? /ask/
end
i = 0
@params.abort_on do

View File

@ -4,7 +4,7 @@ class TestError < TestBase
def test_error
error = Whisper::Error.new(-2)
assert_equal "failed to compute log mel spectrogram", error.message
assert_equal(-2, error.code)
assert_equal -2, error.code
end
def test_unknown_error
@ -14,7 +14,7 @@ class TestError < TestBase
def test_non_int_code
assert_raise TypeError do
_error = Whisper::Error.new("non int")
error = Whisper::Error.new("non int")
end
end
end

View File

@ -21,26 +21,11 @@ class TestPackage < TestBase
match_data = `rake -Tbuild`.match(/(whispercpp-(.+)\.gem)/)
filename = match_data[1]
version = match_data[2]
basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}"
Dir.mktmpdir do |dir|
system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{filename.shellescape}", exception: true
assert_installed dir, version
assert_path_exist File.join(dir, "gems/whispercpp-#{version}/lib", basename)
end
end
private
def assert_installed(dir, version)
assert_path_exist File.join(dir, "gems/whispercpp-#{version}/lib", "whisper.#{RbConfig::CONFIG["DLEXT"]}")
assert_path_exist File.join(dir, "gems/whispercpp-#{version}/LICENSE")
assert_path_not_exist File.join(dir, "gems/whispercpp-#{version}/ext/build")
end
end
def test_build_options
options = BuildOptions::Options.new
assert_empty options.missing_options
unless ENV["CI"]
assert_empty options.extra_options
end
end
end

View File

@ -162,7 +162,7 @@ class TestParams < TestBase
end
def test_length_penalty
assert_equal(-1.0, @params.length_penalty)
assert_equal -1.0, @params.length_penalty
@params.length_penalty = 0.5
assert_equal 0.5, @params.length_penalty
end
@ -180,9 +180,9 @@ class TestParams < TestBase
end
def test_logprob_thold
assert_in_delta(-1.0, @params.logprob_thold)
assert_in_delta -1.0, @params.logprob_thold
@params.logprob_thold = -0.5
assert_in_delta(-0.5, @params.logprob_thold)
assert_in_delta -0.5, @params.logprob_thold
end
def test_no_speech_thold

View File

@ -49,13 +49,13 @@ class TestSegment < TestBase
if index == 0
seg = segment
assert_equal 0, segment.start_time
assert_match(/ask not what your country can do for you, ask what you can do for your country/, segment.text)
assert_match /ask not what your country can do for you, ask what you can do for your country/, segment.text
end
index += 1
end
whisper.transcribe(AUDIO, params)
assert_equal 0, seg.start_time
assert_match(/ask not what your country can do for you, ask what you can do for your country/, seg.text)
assert_match /ask not what your country can do for you, ask what you can do for your country/, seg.text
end
def test_on_new_segment_twice

View File

@ -16,7 +16,7 @@ class TestWhisper < TestBase
params.print_timestamps = false
@whisper.transcribe(AUDIO, params) {|text|
assert_match(/ask not what your country can do for you, ask what you can do for your country/, text)
assert_match /ask not what your country can do for you, ask what you can do for your country/, text
}
end
@ -32,7 +32,7 @@ class TestWhisper < TestBase
def test_full_get_segment
segment = whisper.full_get_segment(0)
assert_equal 0, segment.start_time
assert_match(/ask not what your country can do for you, ask what you can do for your country/, segment.text)
assert_match /ask not what your country can do for you, ask what you can do for your country/, segment.text
end
def test_full_get_segment_t0
@ -59,7 +59,7 @@ class TestWhisper < TestBase
end
def test_full_get_segment_text
assert_match(/ask not what your country can do for you, ask what you can do for your country/, whisper.full_get_segment_text(0))
assert_match /ask not what your country can do for you, ask what you can do for your country/, whisper.full_get_segment_text(0)
end
def test_full_get_segment_no_speech_prob
@ -134,14 +134,14 @@ class TestWhisper < TestBase
@whisper.full(@params, @samples, @samples.length)
assert_equal 1, @whisper.full_n_segments
assert_match(/ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text)
assert_match /ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text
end
def test_full_without_length
@whisper.full(@params, @samples)
assert_equal 1, @whisper.full_n_segments
assert_match(/ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text)
assert_match /ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text
end
def test_full_enumerator
@ -149,7 +149,7 @@ class TestWhisper < TestBase
@whisper.full(@params, samples, @samples.length)
assert_equal 1, @whisper.full_n_segments
assert_match(/ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text)
assert_match /ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text
end
def test_full_enumerator_without_length
@ -171,28 +171,26 @@ class TestWhisper < TestBase
@whisper.full(@params, samples)
assert_equal 1, @whisper.full_n_segments
assert_match(/ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text)
assert_match /ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text
end
def test_full_parallel
nprocessors = 2
@whisper.full_parallel(@params, @samples, @samples.length, nprocessors)
@whisper.full_parallel(@params, @samples, @samples.length, Etc.nprocessors)
assert_equal nprocessors, @whisper.full_n_segments
assert_equal Etc.nprocessors, @whisper.full_n_segments
text = @whisper.each_segment.collect(&:text).join
assert_match(/ask what you can do/i, text)
assert_match(/for your country/i, text)
assert_match /ask what you can do/i, text
assert_match /for your country/i, text
end
def test_full_parallel_with_memory_view
nprocessors = 2
samples = JFKReader.new(AUDIO)
@whisper.full_parallel(@params, samples, nil, nprocessors)
@whisper.full_parallel(@params, samples, nil, Etc.nprocessors)
assert_equal nprocessors, @whisper.full_n_segments
assert_equal Etc.nprocessors, @whisper.full_n_segments
text = @whisper.each_segment.collect(&:text).join
assert_match(/ask what you can do/i, text)
assert_match(/for your country/i, text)
assert_match /ask what you can do/i, text
assert_match /for your country/i, text
end
def test_full_parallel_without_length_and_n_processors
@ -200,18 +198,17 @@ class TestWhisper < TestBase
assert_equal 1, @whisper.full_n_segments
text = @whisper.each_segment.collect(&:text).join
assert_match(/ask what you can do/i, text)
assert_match(/for your country/i, text)
assert_match /ask what you can do/i, text
assert_match /for your country/i, text
end
def test_full_parallel_without_length
nprocessors = 2
@whisper.full_parallel(@params, @samples, nil, nprocessors)
@whisper.full_parallel(@params, @samples, nil, Etc.nprocessors)
assert_equal nprocessors, @whisper.full_n_segments
assert_equal Etc.nprocessors, @whisper.full_n_segments
text = @whisper.each_segment.collect(&:text).join
assert_match(/ask what you can do/i, text)
assert_match(/for your country/i, text)
assert_match /ask what you can do/i, text
assert_match /for your country/i, text
end
def test_full_parallel_without_n_processors
@ -219,8 +216,8 @@ class TestWhisper < TestBase
assert_equal 1, @whisper.full_n_segments
text = @whisper.each_segment.collect(&:text).join
assert_match(/ask what you can do/i, text)
assert_match(/for your country/i, text)
assert_match /ask what you can do/i, text
assert_match /for your country/i, text
end
end
end

View File

@ -3,8 +3,8 @@ require_relative "extsources"
Gem::Specification.new do |s|
s.name = "whispercpp"
s.authors = ["Georgi Gerganov", "Todd A. Fisher"]
s.version = '1.3.2'
s.date = '2025-05-01'
s.version = '1.3.1'
s.date = '2024-12-19'
s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby}
s.email = 'todd.fisher@gmail.com'
s.extra_rdoc_files = ['LICENSE', 'README.md']
@ -15,8 +15,7 @@ Gem::Specification.new do |s|
if s.extra_rdoc_files.include?(basename)
basename
else
file.sub("../..", "ext/sources")
.sub("../javascript", "ext/sources/bindings/javascript")
file.sub("../..", "ext")
end
}
@ -27,7 +26,7 @@ Gem::Specification.new do |s|
s.required_ruby_version = '>= 3.1.0'
#### Documentation and testing.
s.homepage = 'https://github.com/ggml-org/whisper.cpp'
s.homepage = 'https://github.com/ggerganov/whisper.cpp'
s.rdoc_options = ['--main', 'README.md']

View File

@ -41,11 +41,6 @@ COMMON_CMAKE_ARGS=(
-DGGML_OPENMP=${GGML_OPENMP}
)
XCODE_VERSION=$(xcodebuild -version 2>/dev/null | head -n1 | awk '{ print $2 }')
MAJOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f1)
MINOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f2)
echo "Detected Xcode version: $XCODE_VERSION"
check_required_tool() {
local tool=$1
local install_message=$2
@ -113,7 +108,7 @@ setup_framework_structure() {
fi
# Copy all required headers (common for all platforms)
cp include/whisper.h ${header_path}
cp include/whisper.h ${header_path}
cp ggml/include/ggml.h ${header_path}
cp ggml/include/ggml-alloc.h ${header_path}
cp ggml/include/ggml-backend.h ${header_path}
@ -250,16 +245,9 @@ combine_static_libraries() {
"${base_dir}/${build_dir}/ggml/src/ggml-metal/${release_dir}/libggml-metal.a"
"${base_dir}/${build_dir}/ggml/src/ggml-blas/${release_dir}/libggml-blas.a"
)
if [[ "$platform" == "macos" || "$platform" == "ios" ]]; then
echo "Adding libwhisper.coreml library to the build."
libs+=(
"${base_dir}/${build_dir}/src/${release_dir}/libwhisper.coreml.a"
)
fi
# Create temporary directory for processing
local temp_dir="${base_dir}/${build_dir}/temp"
echo "Creating temporary directory: ${temp_dir}"
mkdir -p "${temp_dir}"
# Since we have multiple architectures libtool will find object files that do not
@ -271,7 +259,6 @@ combine_static_libraries() {
local archs=""
local min_version_flag=""
local install_name=""
local frameworks="-framework Foundation -framework Metal -framework Accelerate"
case "$platform" in
"ios")
@ -285,14 +272,12 @@ combine_static_libraries() {
min_version_flag="-mios-version-min=${IOS_MIN_OS_VERSION}"
fi
install_name="@rpath/whisper.framework/whisper"
frameworks+=" -framework CoreML"
;;
"macos")
sdk="macosx"
archs="arm64 x86_64"
min_version_flag="-mmacosx-version-min=${MACOS_MIN_OS_VERSION}"
install_name="@rpath/whisper.framework/Versions/Current/whisper"
frameworks+=" -framework CoreML"
;;
"visionos")
if [[ "$is_simulator" == "true" ]]; then
@ -334,34 +319,27 @@ combine_static_libraries() {
$arch_flags \
$min_version_flag \
-Wl,-force_load,"${temp_dir}/combined.a" \
$frameworks \
-framework Foundation -framework Metal -framework Accelerate \
-install_name "$install_name" \
-o "${base_dir}/${output_lib}"
# Platform-specific post-processing for device builds
if [[ "$is_simulator" == "false" ]]; then
if command -v xcrun vtool &>/dev/null; then
if command -v vtool &>/dev/null; then
case "$platform" in
"ios")
echo "Marking binary as a framework binary for iOS..."
xcrun vtool -set-build-version ios ${IOS_MIN_OS_VERSION} ${IOS_MIN_OS_VERSION} -replace \
vtool -set-build-version ios ${IOS_MIN_OS_VERSION} ${IOS_MIN_OS_VERSION} -replace \
-output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
;;
"visionos")
echo "Marking binary as a framework binary for visionOS..."
if [[ "$MAJOR_VERSION" -gt 16 ]] || [[ "$MAJOR_VERSION" -eq 16 && "$MINOR_VERSION" -gt 2 ]]; then
echo "Xcode version greater than 16.2, using visionOS."
VISION_OS_BUILD_VERSION="visionos"
else
echo "Xcode version less than or equal to 16.2, using xros."
VISION_OS_BUILD_VERSION="xros"
fi
xcrun vtool -set-build-version ${VISION_OS_BUILD_VERSION} ${VISIONOS_MIN_OS_VERSION} ${VISIONOS_MIN_OS_VERSION} -replace \
vtool -set-build-version xros ${VISIONOS_MIN_OS_VERSION} ${VISIONOS_MIN_OS_VERSION} -replace \
-output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
;;
"tvos")
echo "Marking binary as a framework binary for tvOS..."
xcrun vtool -set-build-version tvos ${TVOS_MIN_OS_VERSION} ${TVOS_MIN_OS_VERSION} -replace \
vtool -set-build-version tvos ${TVOS_MIN_OS_VERSION} ${TVOS_MIN_OS_VERSION} -replace \
-output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
;;
esac
@ -421,8 +399,6 @@ cmake -B build-ios-sim -G Xcode \
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=iphonesimulator \
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
-DWHISPER_COREML="ON" \
-DWHISPER_COREML_ALLOW_FALLBACK="ON" \
-S .
cmake --build build-ios-sim --config Release -- -quiet
@ -435,8 +411,6 @@ cmake -B build-ios-device -G Xcode \
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=iphoneos \
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
-DWHISPER_COREML="ON" \
-DWHISPER_COREML_ALLOW_FALLBACK="ON" \
-S .
cmake --build build-ios-device --config Release -- -quiet
@ -447,8 +421,6 @@ cmake -B build-macos -G Xcode \
-DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" \
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
-DWHISPER_COREML="ON" \
-DWHISPER_COREML_ALLOW_FALLBACK="ON" \
-S .
cmake --build build-macos --config Release -- -quiet
@ -460,8 +432,8 @@ cmake -B build-visionos -G Xcode \
-DCMAKE_SYSTEM_NAME=visionOS \
-DCMAKE_OSX_SYSROOT=xros \
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xros \
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_CXX_FLAGS}" \
-S .
cmake --build build-visionos --config Release -- -quiet
@ -473,8 +445,8 @@ cmake -B build-visionos-sim -G Xcode \
-DCMAKE_SYSTEM_NAME=visionOS \
-DCMAKE_OSX_SYSROOT=xrsimulator \
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xrsimulator \
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_CXX_FLAGS}" \
-S .
cmake --build build-visionos-sim --config Release -- -quiet

View File

@ -10,8 +10,6 @@
# # with CUDA support
# GG_BUILD_CUDA=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
#
# # with SYCL support
# GG_BUILD_SYCL=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
if [ -z "$2" ]; then
echo "usage: $0 <output-dir> <mnt-dir>"
@ -326,9 +324,8 @@ ret=0
for model in "${MODELS[@]}"; do
test $ret -eq 0 && gg_download_model ${model}
done
if [ -z ${GG_BUILD_SYCL}]; then
test $ret -eq 0 && gg_run ctest debug
fi
test $ret -eq 0 && gg_run ctest debug
test $ret -eq 0 && gg_run ctest release
test $ret -eq 0 && gg_run bench

View File

@ -18,13 +18,6 @@ const whisperParamsMock = {
translate: true,
no_timestamps: false,
audio_ctx: 0,
max_len: 0,
prompt: "",
print_progress: false,
progress_callback: (progress) => {
console.log(`Progress: ${progress}`);
},
max_context: -1
};
describe("Run whisper.node", () => {

View File

@ -128,227 +128,192 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
void cb_log_disable(enum ggml_log_level, const char *, void *) {}
class ProgressWorker : public Napi::AsyncWorker {
public:
ProgressWorker(Napi::Function& callback, whisper_params params, Napi::Function progress_callback, Napi::Env env)
: Napi::AsyncWorker(callback), params(params), env(env) {
// Create thread-safe function
if (!progress_callback.IsEmpty()) {
tsfn = Napi::ThreadSafeFunction::New(
env,
progress_callback,
"Progress Callback",
0,
1
);
}
int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
if (params.no_prints) {
whisper_log_set(cb_log_disable, NULL);
}
~ProgressWorker() {
if (tsfn) {
// Make sure to release the thread-safe function on destruction
tsfn.Release();
}
if (params.fname_inp.empty() && params.pcmf32.empty()) {
fprintf(stderr, "error: no input files or audio buffer specified\n");
return 2;
}
void Execute() override {
// Use custom run function with progress callback support
run_with_progress(params, result);
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
exit(0);
}
void OnOK() override {
Napi::HandleScope scope(Env());
Napi::Object res = Napi::Array::New(Env(), result.size());
for (uint64_t i = 0; i < result.size(); ++i) {
Napi::Object tmp = Napi::Array::New(Env(), 3);
for (uint64_t j = 0; j < 3; ++j) {
tmp[j] = Napi::String::New(Env(), result[i][j]);
// whisper init
struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu;
cparams.flash_attn = params.flash_attn;
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
if (ctx == nullptr) {
fprintf(stderr, "error: failed to initialize whisper context\n");
return 3;
}
// if params.pcmf32 is provided, set params.fname_inp to "buffer"
// this is simpler than further modifications in the code
if (!params.pcmf32.empty()) {
fprintf(stderr, "info: using audio buffer as input\n");
params.fname_inp.clear();
params.fname_inp.emplace_back("buffer");
}
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f];
const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
// read the input audio file if params.pcmf32 is not provided
if (params.pcmf32.empty()) {
if (!::read_audio_data(fname_inp, pcmf32, pcmf32s, params.diarize)) {
fprintf(stderr, "error: failed to read audio file '%s'\n", fname_inp.c_str());
continue;
}
} else {
pcmf32 = params.pcmf32;
}
// print system information
if (!params.no_prints) {
fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info());
}
// print some info about the processing
if (!params.no_prints) {
fprintf(stderr, "\n");
if (!whisper_is_multilingual(ctx)) {
if (params.language != "en" || params.translate) {
params.language = "en";
params.translate = false;
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
}
}
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d, audio_ctx = %d ...\n",
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
params.n_threads, params.n_processors,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.no_timestamps ? 0 : 1,
params.audio_ctx);
fprintf(stderr, "\n");
}
// run the inference
{
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
wparams.print_realtime = false;
wparams.print_progress = params.print_progress;
wparams.print_timestamps = !params.no_timestamps;
wparams.print_special = params.print_special;
wparams.translate = params.translate;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
wparams.offset_ms = params.offset_t_ms;
wparams.duration_ms = params.duration_ms;
wparams.token_timestamps = params.output_wts || params.max_len > 0;
wparams.thold_pt = params.word_thold;
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.audio_ctx = params.audio_ctx;
wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;
wparams.initial_prompt = params.prompt.c_str();
wparams.no_timestamps = params.no_timestamps;
whisper_print_user_data user_data = { &params, &pcmf32s };
// this callback is called on each new segment
if (!wparams.print_realtime) {
wparams.new_segment_callback = whisper_print_segment_callback;
wparams.new_segment_callback_user_data = &user_data;
}
// example for abort mechanism
// in this example, we do not abort the processing, but we could if the flag is set to true
// the callback is called before every encoder run - if it returns false, the processing is aborted
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
wparams.encoder_begin_callback_user_data = &is_aborted;
}
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
fprintf(stderr, "failed to process audio\n");
return 10;
}
res[i] = tmp;
}
Callback().Call({Env().Null(), res});
}
// Progress callback function - using thread-safe function
void OnProgress(int progress) {
if (tsfn) {
// Use thread-safe function to call JavaScript callback
auto callback = [progress](Napi::Env env, Napi::Function jsCallback) {
jsCallback.Call({Napi::Number::New(env, progress)});
};
tsfn.BlockingCall(callback);
}
const int n_segments = whisper_full_n_segments(ctx);
result.resize(n_segments);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
result[i].emplace_back(to_timestamp(t0, params.comma_in_time));
result[i].emplace_back(to_timestamp(t1, params.comma_in_time));
result[i].emplace_back(text);
}
whisper_print_timings(ctx);
whisper_free(ctx);
return 0;
}
class Worker : public Napi::AsyncWorker {
public:
Worker(Napi::Function& callback, whisper_params params)
: Napi::AsyncWorker(callback), params(params) {}
void Execute() override {
run(params, result);
}
void OnOK() override {
Napi::HandleScope scope(Env());
Napi::Object res = Napi::Array::New(Env(), result.size());
for (uint64_t i = 0; i < result.size(); ++i) {
Napi::Object tmp = Napi::Array::New(Env(), 3);
for (uint64_t j = 0; j < 3; ++j) {
tmp[j] = Napi::String::New(Env(), result[i][j]);
}
res[i] = tmp;
}
Callback().Call({Env().Null(), res});
}
private:
whisper_params params;
std::vector<std::vector<std::string>> result;
Napi::Env env;
Napi::ThreadSafeFunction tsfn;
// Custom run function with progress callback support
int run_with_progress(whisper_params &params, std::vector<std::vector<std::string>> &result) {
if (params.no_prints) {
whisper_log_set(cb_log_disable, NULL);
}
if (params.fname_inp.empty() && params.pcmf32.empty()) {
fprintf(stderr, "error: no input files or audio buffer specified\n");
return 2;
}
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
exit(0);
}
// whisper init
struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu;
cparams.flash_attn = params.flash_attn;
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
if (ctx == nullptr) {
fprintf(stderr, "error: failed to initialize whisper context\n");
return 3;
}
// If params.pcmf32 provides, set params.fname_inp as "buffer"
if (!params.pcmf32.empty()) {
fprintf(stderr, "info: using audio buffer as input\n");
params.fname_inp.clear();
params.fname_inp.emplace_back("buffer");
}
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f];
const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
// If params.pcmf32 is empty, read input audio file
if (params.pcmf32.empty()) {
if (!::read_audio_data(fname_inp, pcmf32, pcmf32s, params.diarize)) {
fprintf(stderr, "error: failed to read audio file '%s'\n", fname_inp.c_str());
continue;
}
} else {
pcmf32 = params.pcmf32;
}
// Print system info
if (!params.no_prints) {
fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info());
}
// Print processing info
if (!params.no_prints) {
fprintf(stderr, "\n");
if (!whisper_is_multilingual(ctx)) {
if (params.language != "en" || params.translate) {
params.language = "en";
params.translate = false;
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
}
}
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d, audio_ctx = %d ...\n",
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
params.n_threads, params.n_processors,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.no_timestamps ? 0 : 1,
params.audio_ctx);
fprintf(stderr, "\n");
}
// Run inference
{
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
wparams.print_realtime = false;
wparams.print_progress = params.print_progress;
wparams.print_timestamps = !params.no_timestamps;
wparams.print_special = params.print_special;
wparams.translate = params.translate;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
wparams.offset_ms = params.offset_t_ms;
wparams.duration_ms = params.duration_ms;
wparams.token_timestamps = params.output_wts || params.max_len > 0;
wparams.thold_pt = params.word_thold;
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.audio_ctx = params.audio_ctx;
wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;
wparams.initial_prompt = params.prompt.c_str();
wparams.no_timestamps = params.no_timestamps;
whisper_print_user_data user_data = { &params, &pcmf32s };
// This callback is called for each new segment
if (!wparams.print_realtime) {
wparams.new_segment_callback = whisper_print_segment_callback;
wparams.new_segment_callback_user_data = &user_data;
}
// Set progress callback
wparams.progress_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void * user_data) {
ProgressWorker* worker = static_cast<ProgressWorker*>(user_data);
worker->OnProgress(progress);
};
wparams.progress_callback_user_data = this;
// Abort mechanism example
{
static bool is_aborted = false; // Note: this should be atomic to avoid data races
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
wparams.encoder_begin_callback_user_data = &is_aborted;
}
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
fprintf(stderr, "failed to process audio\n");
return 10;
}
}
}
const int n_segments = whisper_full_n_segments(ctx);
result.resize(n_segments);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
result[i].emplace_back(to_timestamp(t0, params.comma_in_time));
result[i].emplace_back(to_timestamp(t1, params.comma_in_time));
result[i].emplace_back(text);
}
whisper_print_timings(ctx);
whisper_free(ctx);
return 0;
}
whisper_params params;
std::vector<std::vector<std::string>> result;
};
Napi::Value whisper(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env();
if (info.Length() <= 0 || !info[0].IsObject()) {
@ -367,29 +332,6 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
int32_t audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>();
bool comma_in_time = whisper_params.Get("comma_in_time").As<Napi::Boolean>();
int32_t max_len = whisper_params.Get("max_len").As<Napi::Number>();
// Add support for max_context
int32_t max_context = -1;
if (whisper_params.Has("max_context") && whisper_params.Get("max_context").IsNumber()) {
max_context = whisper_params.Get("max_context").As<Napi::Number>();
}
// support prompt
std::string prompt = "";
if (whisper_params.Has("prompt") && whisper_params.Get("prompt").IsString()) {
prompt = whisper_params.Get("prompt").As<Napi::String>();
}
// Add support for print_progress
bool print_progress = false;
if (whisper_params.Has("print_progress")) {
print_progress = whisper_params.Get("print_progress").As<Napi::Boolean>();
}
// Add support for progress_callback
Napi::Function progress_callback;
if (whisper_params.Has("progress_callback") && whisper_params.Get("progress_callback").IsFunction()) {
progress_callback = whisper_params.Get("progress_callback").As<Napi::Function>();
}
Napi::Value pcmf32Value = whisper_params.Get("pcmf32");
std::vector<float> pcmf32_vec;
@ -413,13 +355,9 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
params.pcmf32 = pcmf32_vec;
params.comma_in_time = comma_in_time;
params.max_len = max_len;
params.max_context = max_context;
params.print_progress = print_progress;
params.prompt = prompt;
Napi::Function callback = info[1].As<Napi::Function>();
// Create a new Worker class with progress callback support
ProgressWorker* worker = new ProgressWorker(callback, params, progress_callback, env);
Worker* worker = new Worker(callback, params);
worker->Queue();
return env.Undefined();
}

View File

@ -19,9 +19,6 @@ const whisperParams = {
no_timestamps: false,
audio_ctx: 0,
max_len: 0,
progress_callback: (progress) => {
console.log(`progress: ${progress}%`);
}
};
const arguments = process.argv.slice(2);

View File

@ -2,7 +2,7 @@
Benchmark the performance of whisper.cpp in the browser using WebAssembly
Link: https://ggerganov.github.io/whisper.cpp/bench.wasm
Link: https://whisper.ggerganov.com/bench/
Terminal version: [examples/bench](/examples/bench)
@ -15,17 +15,7 @@ cd whisper.cpp
mkdir build-em && cd build-em
emcmake cmake ..
make -j
```
The example can then be started by running a local HTTP server:
```console
python3 examples/server.py
```
And then opening a browser to the following URL:
http://localhost:8000/bench.wasm
To run the example in a different server, you need to copy the following files
to the server's HTTP path:
```
# copy the produced page to your HTTP path
cp bin/bench.wasm/* /path/to/html/
cp bin/libbench.worker.js /path/to/html/

View File

@ -24,8 +24,6 @@
overflow-x: scroll;
}
</style>
<script src="../coi-serviceworker.js"></script>
<link rel="icon" href="data:,">
</head>
<body>
<div id="main-container">
@ -38,10 +36,11 @@
<br><br>
<b>More examples:</b>
<a href="../">main</a> |
<a href="../bench.wasm/">bench</a> |
<a href="../stream.wasm">stream</a> |
<a href="../command.wasm/">command</a> |
<a href="https://whisper.ggerganov.com/">main</a> |
<a href="https://whisper.ggerganov.com/bench">bench</a> |
<a href="https://whisper.ggerganov.com/stream">stream</a> |
<a href="https://whisper.ggerganov.com/command">command</a> |
<a href="https://whisper.ggerganov.com/talk">talk</a> |
<br><br>

View File

@ -4,7 +4,7 @@ A very basic tool for benchmarking the inference performance on your device. The
the transformer on some random audio data and records the execution time. This way we can have an objective comparison
of the performance of the model for various setups.
Benchmark results are tracked in the following Github issue: https://github.com/ggml-org/whisper.cpp/issues/89
Benchmark results are tracked in the following Github issue: https://github.com/ggerganov/whisper.cpp/issues/89
```bash
# run the bench too on the small.en model using 4 threads
@ -40,7 +40,7 @@ system_info: n_threads = 4 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WA
If you wish, you can submit these results here:
https://github.com/ggml-org/whisper.cpp/issues/89
https://github.com/ggerganov/whisper.cpp/issues/89
Please include the following information:

View File

@ -6,8 +6,7 @@ It can be used as a reference for using the `whisper.cpp` library in other proje
```
./build/bin/whisper-cli -h
usage: ./build/bin/whisper-cli [options] file0 file1 ...
supported audio formats: flac, mp3, ogg, wav
usage: ./build-pkg/bin/whisper-cli [options] file0.wav file1.wav ...
options:
-h, --help [default] show this help message and exit
@ -25,7 +24,6 @@ options:
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
-nth N, --no-speech-thold N [0.60 ] no speech threshold
-tp, --temperature N [0.00 ] The sampling temperature, between 0 and 1
-tpi, --temperature-inc N [0.20 ] The increment of temperature, between 0 and 1
-debug, --debug-mode [false ] enable debug mode (eg. dump log_mel)
@ -52,13 +50,12 @@ options:
-dl, --detect-language [false ] exit after automatically detecting language
--prompt PROMPT [ ] initial prompt (max n_text_ctx/2 tokens)
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-f FNAME, --file FNAME [ ] input audio file path
-f FNAME, --file FNAME [ ] input WAV file path
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
-dtw MODEL --dtw MODEL [ ] compute token-level timestamps
-ls, --log-score [false ] log best decoder scores of tokens
-ng, --no-gpu [false ] disable GPU
-fa, --flash-attn [false ] flash attention
-sns, --suppress-nst [false ] suppress non-speech tokens
--suppress-regex REGEX [ ] regular expression matching tokens to suppress
--grammar GRAMMAR [ ] GBNF grammar to guide decoding
--grammar-rule RULE [ ] top-level GBNF grammar rule name

View File

@ -13,12 +13,14 @@
#include <cstring>
#if defined(_WIN32)
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#endif
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
// helper function to replace substrings
static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
for (size_t pos = 0; ; pos += replace.length()) {
@ -375,7 +377,15 @@ static void whisper_print_segment_callback(struct whisper_context * ctx, struct
}
}
static void output_txt(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
static bool output_txt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
@ -390,9 +400,19 @@ static void output_txt(struct whisper_context * ctx, std::ofstream & fout, const
fout << speaker << text << "\n";
}
return true;
}
static void output_vtt(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
static bool output_vtt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
fout << "WEBVTT\n\n";
const int n_segments = whisper_full_n_segments(ctx);
@ -412,9 +432,19 @@ static void output_vtt(struct whisper_context * ctx, std::ofstream & fout, const
fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
fout << speaker << text << "\n\n";
}
return true;
}
static void output_srt(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
static bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
@ -431,6 +461,8 @@ static void output_srt(struct whisper_context * ctx, std::ofstream & fout, const
fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n";
fout << speaker << text << "\n\n";
}
return true;
}
static char * escape_double_quotes_and_backslashes(const char * str) {
@ -496,7 +528,15 @@ static char * escape_double_quotes_in_csv(const char * str) {
return escaped;
}
static void output_csv(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
static bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx);
fout << "start,end,";
if (params.diarize && pcmf32s.size() == 2)
@ -519,9 +559,14 @@ static void output_csv(struct whisper_context * ctx, std::ofstream & fout, const
}
fout << "\"" << text_escaped << "\"\n";
}
return true;
}
static void output_score(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & /*params*/, std::vector<std::vector<float>> /*pcmf32s*/) {
static bool output_score(struct whisper_context * ctx, const char * fname, const whisper_params & /*params*/, std::vector<std::vector<float>> /*pcmf32s*/) {
std::ofstream fout(fname);
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx);
// fprintf(stderr,"segments: %d\n",n_segments);
for (int i = 0; i < n_segments; ++i) {
@ -534,14 +579,16 @@ static void output_score(struct whisper_context * ctx, std::ofstream & fout, con
// fprintf(stderr,"token: %s %f\n",token,probability);
}
}
return true;
}
static void output_json(
static bool output_json(
struct whisper_context * ctx,
std::ofstream & fout,
const char * fname,
const whisper_params & params,
std::vector<std::vector<float>> pcmf32s) {
const bool full = params.output_jsn_full;
std::vector<std::vector<float>> pcmf32s,
bool full) {
std::ofstream fout(fname);
int indent = 0;
auto doindent = [&]() {
@ -621,6 +668,12 @@ static void output_json(
end_obj(end);
};
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
start_obj(nullptr);
value_s("systeminfo", whisper_print_system_info(), false);
start_obj("model");
@ -694,12 +747,17 @@ static void output_json(
end_arr(true);
end_obj(true);
return true;
}
// karaoke video generation
// outputs a bash script that uses ffmpeg to generate a video with the subtitles
// TODO: font parameter adjustments
static bool output_wts(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector<std::vector<float>> pcmf32s, const char * fname_inp, float t_sec, const char * fname_out) {
static bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
static const char * font = params.font_path.c_str();
std::ifstream fin(font);
@ -815,12 +873,20 @@ static bool output_wts(struct whisper_context * ctx, std::ofstream & fout, const
fout.close();
fprintf(stderr, "# %s: run 'source %s' to generate karaoke video\n", __func__, fname_out);
fprintf(stderr, "%s: run 'source %s' to generate karaoke video\n", __func__, fname);
return true;
}
static void output_lrc(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
static bool output_lrc(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
fout << "[by:whisper.cpp]\n";
const int n_segments = whisper_full_n_segments(ctx);
@ -848,6 +914,8 @@ static void output_lrc(struct whisper_context * ctx, std::ofstream & fout, const
fout << '[' << timestamp_lrc << ']' << speaker << text << "\n";
}
return true;
}
@ -996,55 +1064,8 @@ int main(int argc, char ** argv) {
}
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto & fname_inp = params.fname_inp[f];
struct fout_factory {
std::string fname_out;
const size_t basename_length;
const bool is_stdout;
bool used_stdout;
decltype(whisper_print_segment_callback) * const print_segment_callback;
std::ofstream fout;
fout_factory (const std::string & fname_out_, const std::string & fname_inp, whisper_params & params) :
fname_out{!fname_out_.empty() ? fname_out_ : fname_inp},
basename_length{fname_out.size()},
is_stdout{fname_out == "-"},
used_stdout{},
print_segment_callback{is_stdout ? nullptr : whisper_print_segment_callback} {
if (!print_segment_callback) {
params.print_progress = false;
}
}
bool open(const char * ext, const char * function) {
if (is_stdout) {
if (used_stdout) {
fprintf(stderr, "warning: Not appending multiple file formats to stdout\n");
return false;
}
used_stdout = true;
#ifdef _WIN32
fout = std::ofstream{"CON"};
#else
fout = std::ofstream{"/dev/stdout"};
#endif
// Not using fprintf stderr here because it might equal stdout
// Also assuming /dev is mounted
return true;
}
fname_out.resize(basename_length);
fname_out += ext;
fout = std::ofstream{fname_out};
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", function, fname_out.c_str());
return true;
}
} fout_factory{f < (int) params.fname_out.size() ? params.fname_out[f] : "", fname_inp, params};
const auto fname_inp = params.fname_inp[f];
const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
@ -1149,7 +1170,7 @@ int main(int argc, char ** argv) {
// this callback is called on each new segment
if (!wparams.print_realtime) {
wparams.new_segment_callback = fout_factory.print_segment_callback;
wparams.new_segment_callback = whisper_print_segment_callback;
wparams.new_segment_callback_user_data = &user_data;
}
@ -1191,26 +1212,54 @@ int main(int argc, char ** argv) {
// output stuff
{
// macros to stringify function name
#define output_func(func, ext, param, ...) if (param && fout_factory.open(ext, #func)) {\
func(ctx, fout_factory.fout, params, __VA_ARGS__); \
}
#define output_ext(ext, ...) output_func(output_##ext, "." #ext, params.output_##ext, __VA_ARGS__)
printf("\n");
output_ext(txt, pcmf32s);
output_ext(vtt, pcmf32s);
output_ext(srt, pcmf32s);
output_ext(wts, pcmf32s, fname_inp.c_str(), float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE, fout_factory.fname_out.c_str());
output_ext(csv, pcmf32s);
output_func(output_json, ".json", params.output_jsn, pcmf32s);
output_ext(lrc, pcmf32s);
output_func(output_score, ".score.txt", params.log_score, pcmf32s);
// output to text file
if (params.output_txt) {
const auto fname_txt = fname_out + ".txt";
output_txt(ctx, fname_txt.c_str(), params, pcmf32s);
}
#undef output_ext
#undef output_func
// output to VTT file
if (params.output_vtt) {
const auto fname_vtt = fname_out + ".vtt";
output_vtt(ctx, fname_vtt.c_str(), params, pcmf32s);
}
if (fout_factory.is_stdout && !fout_factory.used_stdout) {
fprintf(stderr, "warning: '--output-file -' used without any other '--output-*'");
// output to SRT file
if (params.output_srt) {
const auto fname_srt = fname_out + ".srt";
output_srt(ctx, fname_srt.c_str(), params, pcmf32s);
}
// output to WTS file
if (params.output_wts) {
const auto fname_wts = fname_out + ".wts";
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE, pcmf32s);
}
// output to CSV file
if (params.output_csv) {
const auto fname_csv = fname_out + ".csv";
output_csv(ctx, fname_csv.c_str(), params, pcmf32s);
}
// output to JSON file
if (params.output_jsn) {
const auto fname_jsn = fname_out + ".json";
output_json(ctx, fname_jsn.c_str(), params, pcmf32s, params.output_jsn_full);
}
// output to LRC file
if (params.output_lrc) {
const auto fname_lrc = fname_out + ".lrc";
output_lrc(ctx, fname_lrc.c_str(), params, pcmf32s);
}
// output to score file
if (params.log_score) {
const auto fname_score = fname_out + ".score.txt";
output_score(ctx, fname_score.c_str(), params, pcmf32s);
}
}
}

View File

@ -1,146 +0,0 @@
/*! coi-serviceworker v0.1.7 - Guido Zuidhof and contributors, licensed under MIT */
let coepCredentialless = false;
if (typeof window === 'undefined') {
self.addEventListener("install", () => self.skipWaiting());
self.addEventListener("activate", (event) => event.waitUntil(self.clients.claim()));
self.addEventListener("message", (ev) => {
if (!ev.data) {
return;
} else if (ev.data.type === "deregister") {
self.registration
.unregister()
.then(() => {
return self.clients.matchAll();
})
.then(clients => {
clients.forEach((client) => client.navigate(client.url));
});
} else if (ev.data.type === "coepCredentialless") {
coepCredentialless = ev.data.value;
}
});
self.addEventListener("fetch", function (event) {
const r = event.request;
if (r.cache === "only-if-cached" && r.mode !== "same-origin") {
return;
}
const request = (coepCredentialless && r.mode === "no-cors")
? new Request(r, {
credentials: "omit",
})
: r;
event.respondWith(
fetch(request)
.then((response) => {
if (response.status === 0) {
return response;
}
const newHeaders = new Headers(response.headers);
newHeaders.set("Cross-Origin-Embedder-Policy",
coepCredentialless ? "credentialless" : "require-corp"
);
if (!coepCredentialless) {
newHeaders.set("Cross-Origin-Resource-Policy", "cross-origin");
}
newHeaders.set("Cross-Origin-Opener-Policy", "same-origin");
return new Response(response.body, {
status: response.status,
statusText: response.statusText,
headers: newHeaders,
});
})
.catch((e) => console.error(e))
);
});
} else {
(() => {
const reloadedBySelf = window.sessionStorage.getItem("coiReloadedBySelf");
window.sessionStorage.removeItem("coiReloadedBySelf");
const coepDegrading = (reloadedBySelf == "coepdegrade");
// You can customize the behavior of this script through a global `coi` variable.
const coi = {
shouldRegister: () => !reloadedBySelf,
shouldDeregister: () => false,
coepCredentialless: () => true,
coepDegrade: () => true,
doReload: () => window.location.reload(),
quiet: false,
...window.coi
};
const n = navigator;
const controlling = n.serviceWorker && n.serviceWorker.controller;
// Record the failure if the page is served by serviceWorker.
if (controlling && !window.crossOriginIsolated) {
window.sessionStorage.setItem("coiCoepHasFailed", "true");
}
const coepHasFailed = window.sessionStorage.getItem("coiCoepHasFailed");
if (controlling) {
// Reload only on the first failure.
const reloadToDegrade = coi.coepDegrade() && !(
coepDegrading || window.crossOriginIsolated
);
n.serviceWorker.controller.postMessage({
type: "coepCredentialless",
value: (reloadToDegrade || coepHasFailed && coi.coepDegrade())
? false
: coi.coepCredentialless(),
});
if (reloadToDegrade) {
!coi.quiet && console.log("Reloading page to degrade COEP.");
window.sessionStorage.setItem("coiReloadedBySelf", "coepdegrade");
coi.doReload("coepdegrade");
}
if (coi.shouldDeregister()) {
n.serviceWorker.controller.postMessage({ type: "deregister" });
}
}
// If we're already coi: do nothing. Perhaps it's due to this script doing its job, or COOP/COEP are
// already set from the origin server. Also if the browser has no notion of crossOriginIsolated, just give up here.
if (window.crossOriginIsolated !== false || !coi.shouldRegister()) return;
if (!window.isSecureContext) {
!coi.quiet && console.log("COOP/COEP Service Worker not registered, a secure context is required.");
return;
}
// In some environments (e.g. Firefox private mode) this won't be available
if (!n.serviceWorker) {
!coi.quiet && console.error("COOP/COEP Service Worker not registered, perhaps due to private mode.");
return;
}
n.serviceWorker.register(window.document.currentScript.src).then(
(registration) => {
!coi.quiet && console.log("COOP/COEP Service Worker registered", registration.scope);
registration.addEventListener("updatefound", () => {
!coi.quiet && console.log("Reloading page to make use of updated COOP/COEP Service Worker.");
window.sessionStorage.setItem("coiReloadedBySelf", "updatefound");
coi.doReload();
});
// If the registration is active, but it's not controlling the page
if (registration.active && !n.serviceWorker.controller) {
!coi.quiet && console.log("Reloading page to make use of COOP/COEP Service Worker.");
window.sessionStorage.setItem("coiReloadedBySelf", "notcontrolling");
coi.doReload();
}
},
(err) => {
!coi.quiet && console.error("COOP/COEP Service Worker failed to register:", err);
}
);
})();
}

View File

@ -3,7 +3,7 @@
This is a basic Voice Assistant example that accepts voice commands from the microphone.
It runs in fully in the browser via WebAseembly.
Online demo: https://ggerganov.github.io/whisper.cpp/command.wasm
Online demo: https://whisper.ggerganov.com/command/
Terminal version: [examples/command](/examples/command)
@ -15,18 +15,9 @@ git clone https://github.com/ggerganov/whisper.cpp
cd whisper.cpp
mkdir build-em && cd build-em
emcmake cmake ..
make -j libcommand
```
The example can then be started by running a local HTTP server:
```console
python3 examples/server.py
```
And then opening a browser to the following URL:
http://localhost:8000/command.wasm/
make -j
To run the example in a different server, you need to copy the following files
to the server's HTTP path:
```
# copy the produced page to your HTTP path
cp bin/command.wasm/* /path/to/html/
cp bin/libcommand.worker.js /path/to/html/
```

View File

@ -24,8 +24,6 @@
overflow-x: scroll;
}
</style>
<script src="../coi-serviceworker.js"></script>
<link rel="icon" href="data:,">
</head>
<body>
<div id="main-container">
@ -38,10 +36,11 @@
<br><br>
<b>More examples:</b>
<a href="../">main</a> |
<a href="../bench.wasm/">bench</a> |
<a href="../stream.wasm">stream</a> |
<a href="../command.wasm/">command</a> |
<a href="https://whisper.ggerganov.com/">main</a> |
<a href="https://whisper.ggerganov.com/bench">bench</a> |
<a href="https://whisper.ggerganov.com/stream">stream</a> |
<a href="https://whisper.ggerganov.com/command">command</a> |
<a href="https://whisper.ggerganov.com/talk">talk</a> |
<br><br>

View File

@ -3,7 +3,7 @@
// Speak short text commands to the microphone.
// This program will detect your voice command and convert them to text.
//
// ref: https://github.com/ggml-org/whisper.cpp/issues/171
// ref: https://github.com/ggerganov/whisper.cpp/issues/171
//
#include "common-sdl.h"

View File

@ -26,6 +26,10 @@
#define MINIAUDIO_IMPLEMENTATION
#include "miniaudio.h"
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
#ifdef _WIN32
#include <fcntl.h>
#include <io.h>

View File

@ -10,6 +10,10 @@
#include <regex>
#include <sstream>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
// Function to check if the next argument exists
static std::string get_next_arg(int& i, int argc, char** argv, const std::string& flag, gpt_params& params) {
if (i + 1 < argc && argv[i + 1][0] != '-') {
@ -243,6 +247,17 @@ std::map<std::string, int32_t> json_parse(const std::string & fname) {
return result;
}
std::string convert_to_utf8(const std::wstring & input) {
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
return converter.to_bytes(input);
}
std::wstring convert_to_wstring(const std::string & input) {
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
return converter.from_bytes(input);
}
void gpt_split_words(std::string str, std::vector<std::string>& words) {
const std::string pattern = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
const std::regex re(pattern);

View File

@ -1,6 +1,4 @@
add_executable(main ./deprecation-warning.cpp)
add_executable(bench ./deprecation-warning.cpp)
if (WHISPER_SDL2)
add_executable(stream ./deprecation-warning.cpp)
add_executable(command ./deprecation-warning.cpp)
endif()
add_executable(stream ./deprecation-warning.cpp)
add_executable(command ./deprecation-warning.cpp)

View File

@ -194,7 +194,7 @@ static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size)
AVIOContext *avio_ctx;
AVStream *stream;
AVCodecContext *codec;
AVPacket *packet;
AVPacket packet;
AVFrame *frame;
struct SwrContext *swr;
u8 *avio_ctx_buffer;
@ -249,20 +249,6 @@ static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size)
/* prepare resampler */
swr = swr_alloc();
#if LIBAVCODEC_VERSION_MAJOR > 60
AVChannelLayout in_ch_layout = codec->ch_layout;
AVChannelLayout out_ch_layout = AV_CHANNEL_LAYOUT_MONO;
/* Set the source audio layout as-is */
av_opt_set_chlayout(swr, "in_chlayout", &in_ch_layout, 0);
av_opt_set_int(swr, "in_sample_rate", codec->sample_rate, 0);
av_opt_set_sample_fmt(swr, "in_sample_fmt", codec->sample_fmt, 0);
/* Convert it into 16khz Mono */
av_opt_set_chlayout(swr, "out_chlayout", &out_ch_layout, 0);
av_opt_set_int(swr, "out_sample_rate", WAVE_SAMPLE_RATE, 0);
av_opt_set_sample_fmt(swr, "out_sample_fmt", AV_SAMPLE_FMT_S16, 0);
#else
av_opt_set_int(swr, "in_channel_count", codec->channels, 0);
av_opt_set_int(swr, "out_channel_count", 1, 0);
av_opt_set_int(swr, "in_channel_layout", codec->channel_layout, 0);
@ -271,7 +257,6 @@ static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size)
av_opt_set_int(swr, "out_sample_rate", WAVE_SAMPLE_RATE, 0);
av_opt_set_sample_fmt(swr, "in_sample_fmt", codec->sample_fmt, 0);
av_opt_set_sample_fmt(swr, "out_sample_fmt", AV_SAMPLE_FMT_S16, 0);
#endif
swr_init(swr);
if (!swr_is_initialized(swr)) {
@ -279,11 +264,7 @@ static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size)
return -1;
}
packet=av_packet_alloc();
if (!packet) {
LOG("Error allocating the packet\n");
return -1;
}
av_init_packet(&packet);
frame = av_frame_alloc();
if (!frame) {
LOG("Error allocating the frame\n");
@ -293,8 +274,8 @@ static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size)
/* iterate through frames */
*data = NULL;
*size = 0;
while (av_read_frame(fmt_ctx, packet) >= 0) {
avcodec_send_packet(codec, packet);
while (av_read_frame(fmt_ctx, &packet) >= 0) {
avcodec_send_packet(codec, &packet);
err = avcodec_receive_frame(codec, frame);
if (err == AVERROR(EAGAIN))
@ -305,11 +286,10 @@ static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size)
/* Flush any remaining conversion buffers... */
convert_frame(swr, codec, frame, data, size, true);
av_packet_free(&packet);
av_frame_free(&frame);
swr_free(&swr);
//avio_context_free(); // todo?
avcodec_free_context(&codec);
avcodec_close(codec);
avformat_close_input(&fmt_ctx);
avformat_free_context(fmt_ctx);

View File

@ -2,7 +2,7 @@
#
# Transcribe audio livestream by feeding ffmpeg output to whisper.cpp at regular intervals
# Idea by @semiformal-net
# ref: https://github.com/ggml-org/whisper.cpp/issues/185
# ref: https://github.com/ggerganov/whisper.cpp/issues/185
#
set -eo pipefail

View File

@ -1,115 +0,0 @@
import http.server
import socketserver
import os
import sys
from pathlib import Path
import urllib.parse
SCRIPT_DIR = Path(__file__).parent.absolute()
DIRECTORY = os.path.join(SCRIPT_DIR, "../build-em/bin")
DIRECTORY = os.path.abspath(DIRECTORY)
# The context root we want for all applications
CONTEXT_ROOT = "/whisper.cpp"
class CustomHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, directory=DIRECTORY, **kwargs)
def do_GET(self):
# Redirect root to the context root
if self.path == '/':
self.send_response(302)
self.send_header('Location', CONTEXT_ROOT + '/')
self.end_headers()
return
# Handle requests under the context root
if self.path.startswith(CONTEXT_ROOT):
# Remove the context root prefix to get the actual path
actual_path = self.path[len(CONTEXT_ROOT):]
if not actual_path:
self.send_response(302)
self.send_header('Location', CONTEXT_ROOT + '/')
self.end_headers()
return
if '.worker.js' in actual_path:
worker_file = os.path.basename(actual_path)
worker_path = os.path.join(DIRECTORY, worker_file)
if os.path.exists(worker_path):
print(f"Found worker file: {worker_path}")
self.path = '/' + worker_file
else:
print(f"Worker file not found: {worker_path}")
elif actual_path == '/':
self.path = '/whisper.wasm/index.html'
elif actual_path.startswith('/bench.wasm/') or actual_path.startswith('/command.wasm/') or actual_path.startswith('/stream.wasm/'):
# Keep the path as is, just remove the context root
self.path = actual_path
# For all other paths under the context root
else:
# Check if this is a request to a file in whisper.wasm
potential_file = os.path.join(DIRECTORY, 'whisper.wasm', actual_path.lstrip('/'))
if os.path.exists(potential_file) and not os.path.isdir(potential_file):
self.path = '/whisper.wasm' + actual_path
else:
# Try to resolve the file from the base directory
potential_file = os.path.join(DIRECTORY, actual_path.lstrip('/'))
if os.path.exists(potential_file):
self.path = actual_path
# For direct requests to worker files (without context root as these
# are in the build-em/bin directory
elif '.worker.js' in self.path:
worker_file = os.path.basename(self.path)
worker_path = os.path.join(DIRECTORY, worker_file)
if os.path.exists(worker_path):
self.path = '/' + worker_file
# Handle coi-serviceworker.js separately
if 'coi-serviceworker.js' in self.path:
worker_file = "coi-serviceworker.js"
worker_path = os.path.join(SCRIPT_DIR, worker_file)
if os.path.exists(worker_path):
self.send_response(200)
self.send_header('Content-type', 'application/javascript')
self.end_headers()
with open(worker_path, 'rb') as file:
self.wfile.write(file.read())
return
else:
print(f"Warning: Could not find {worker_path}")
return super().do_GET()
def end_headers(self):
# Add required headers for SharedArrayBuffer
self.send_header("Cross-Origin-Opener-Policy", "same-origin")
self.send_header("Cross-Origin-Embedder-Policy", "require-corp")
self.send_header("Access-Control-Allow-Origin", "*")
super().end_headers()
PORT = 8000
# Enable address reuse
class CustomServer(socketserver.TCPServer):
allow_reuse_address = True
try:
with CustomServer(("", PORT), CustomHTTPRequestHandler) as httpd:
print(f"Serving directory '{DIRECTORY}' at http://localhost:{PORT}")
print(f"Application context root: http://localhost:{PORT}{CONTEXT_ROOT}/")
try:
httpd.serve_forever()
except KeyboardInterrupt:
print("\nServer stopped.")
# Force complete exit
sys.exit(0)
except OSError as e:
print(f"Error: {e}")
sys.exit(1)

File diff suppressed because it is too large Load Diff

View File

@ -14,6 +14,10 @@
#include <thread>
#include <vector>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
using namespace httplib;
using json = nlohmann::ordered_json;
@ -75,7 +79,6 @@ struct whisper_params {
bool use_gpu = true;
bool flash_attn = false;
bool suppress_nst = false;
bool no_context = false;
std::string language = "en";
std::string prompt = "";
@ -137,8 +140,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server\n", sparams.ffmpeg_converter ? "true" : "false");
fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false");
fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold);
fprintf(stderr, " -nc, --no-context [%-7s] do not use previous audio context\n", params.no_context ? "true" : "false");
fprintf(stderr, " -ng, --no-gpu [%-7s] do not use gpu\n", params.use_gpu ? "false" : "true");
fprintf(stderr, "\n");
}
@ -185,7 +186,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; }
else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(argv[++i]); }
else if (arg == "-nc" || arg == "--no-context") { params.no_context = true; }
// server params
else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
@ -506,10 +506,6 @@ void get_req_parameters(const Request & req, whisper_params & params)
{
params.suppress_nst = parse_str_to_bool(req.get_file_value("suppress_nst").content);
}
if (req.has_file("no_context"))
{
params.no_context = parse_str_to_bool(req.get_file_value("no_context").content);
}
}
} // namespace
@ -822,7 +818,6 @@ int main(int argc, char ** argv) {
wparams.no_timestamps = params.no_timestamps;
wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format;
wparams.no_context = params.no_context;
wparams.suppress_nst = params.suppress_nst;
@ -839,25 +834,33 @@ int main(int argc, char ** argv) {
wparams.progress_callback_user_data = &user_data;
}
// tell whisper to abort if the HTTP connection closed
wparams.abort_callback = [](void *user_data) {
// user_data is a pointer to our Request
auto req_ptr = static_cast<const httplib::Request*>(user_data);
return req_ptr->is_connection_closed();
};
wparams.abort_callback_user_data = (void*)&req;
// examples for abort mechanism
// in examples below, we do not abort the processing, but we could if the flag is set to true
// the callback is called before every encoder run - if it returns false, the processing is aborted
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
wparams.encoder_begin_callback_user_data = &is_aborted;
}
// the callback is called before every computation - if it returns true, the computation is aborted
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.abort_callback = [](void * user_data) {
bool is_aborted = *(bool*)user_data;
return is_aborted;
};
wparams.abort_callback_user_data = &is_aborted;
}
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
// handle failure or early abort
if (req.is_connection_closed()) {
// log client disconnect
fprintf(stderr, "client disconnected, aborted processing\n");
res.status = 499; // Client Closed Request (nginx convention)
res.set_content("{\"error\":\"client disconnected\"}", "application/json");
return;
}
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
res.status = 500; // Internal Server Error
const std::string error_resp = "{\"error\":\"failed to process audio\"}";
res.set_content(error_resp, "application/json");
return;
@ -915,26 +918,14 @@ int main(int argc, char ** argv) {
res.set_content(ss.str(), "text/vtt");
} else if (params.response_format == vjson_format) {
/* try to match openai/whisper's Python format */
std::string results = output_str(ctx, params, pcmf32s);
// Get language probabilities
std::vector<float> lang_probs(whisper_lang_max_id() + 1, 0.0f);
const auto detected_lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, lang_probs.data());
std::string results = output_str(ctx, params, pcmf32s);
json jres = json{
{"task", params.translate ? "translate" : "transcribe"},
{"language", whisper_lang_str_full(whisper_full_lang_id(ctx))},
{"duration", float(pcmf32.size())/WHISPER_SAMPLE_RATE},
{"text", results},
{"segments", json::array()},
{"detected_language", whisper_lang_str_full(detected_lang_id)},
{"detected_language_probability", lang_probs[detected_lang_id]},
{"language_probabilities", json::object()}
{"segments", json::array()}
};
// Add all language probabilities
for (int i = 0; i <= whisper_lang_max_id(); ++i) {
if (lang_probs[i] > 0.001f) { // Only include non-negligible probabilities
jres["language_probabilities"][whisper_lang_str(i)] = lang_probs[i];
}
}
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i)
{
@ -1033,11 +1024,6 @@ int main(int argc, char ** argv) {
// check if the model is in the file system
});
svr.Get(sparams.request_path + "/health", [&](const Request &, Response &res){
const std::string health_response = "{\"status\":\"ok\"}";
res.set_content(health_response, "application/json");
});
svr.set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) {
const char fmt[] = "500 Internal Server Error\n%s";
char buf[BUFSIZ];

View File

@ -13,17 +13,7 @@ cd whisper.cpp
mkdir build-em && cd build-em
emcmake cmake ..
make -j
```
The example can then be started by running a local HTTP server:
```console
python3 examples/server.py
```
And then opening a browser to the following URL:
http://localhost:8000/stream.wasm
To run the example in a different server, you need to copy the following files
to the server's HTTP path:
```
# copy the produced page to your HTTP path
cp bin/stream.wasm/* /path/to/html/
cp bin/libstream.worker.js /path/to/html/

View File

@ -24,8 +24,6 @@
overflow-x: scroll;
}
</style>
<script src="../coi-serviceworker.js"></script>
<link rel="icon" href="data:,">
</head>
<body>
<div id="main-container">
@ -38,10 +36,11 @@
<br><br>
<b>More examples:</b>
<a href="../">main</a> |
<a href="../bench.wasm/">bench</a> |
<a href="../stream.wasm">stream</a> |
<a href="../command.wasm/">command</a> |
<a href="https://whisper.ggerganov.com/">main</a> |
<a href="https://whisper.ggerganov.com/bench">bench</a> |
<a href="https://whisper.ggerganov.com/stream">stream</a> |
<a href="https://whisper.ggerganov.com/command">command</a> |
<a href="https://whisper.ggerganov.com/talk">talk</a> |
<br><br>

View File

@ -12,12 +12,9 @@ if (WHISPER_SDL2)
llama-context.cpp
llama-cparams.cpp
llama-grammar.cpp
llama-graph.cpp
llama-hparams.cpp
llama-impl.cpp
llama-io.cpp
llama-kv-cache.cpp
llama-memory.cpp
llama-mmap.cpp
llama-model-loader.cpp
llama-model.cpp

View File

@ -4,13 +4,14 @@
#include "llama-mmap.h"
#include "llama-model.h"
#include <algorithm>
#include <map>
#include <cassert>
#include <stdexcept>
// vec
ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) {
return nullptr;
}
@ -18,7 +19,7 @@ ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
return tensors[il];
}
ggml_tensor * llama_adapter_cvec::apply_to(ggml_context * ctx, ggml_tensor * cur, int il) const {
struct ggml_tensor * llama_adapter_cvec::apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const {
ggml_tensor * layer_dir = tensor_for(il);
if (layer_dir != nullptr) {
cur = ggml_add(ctx, cur, layer_dir);
@ -39,7 +40,7 @@ bool llama_adapter_cvec::init(const llama_model & model) {
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
ggml_init_params params = {
struct ggml_init_params params = {
/*.mem_size =*/ hparams.n_layer*ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
@ -90,7 +91,7 @@ bool llama_adapter_cvec::init(const llama_model & model) {
return true;
}
bool llama_adapter_cvec::apply(
int32_t llama_adapter_cvec::apply(
const llama_model & model,
const float * data,
size_t len,
@ -103,17 +104,17 @@ bool llama_adapter_cvec::apply(
// disable the current control vector (but leave allocated for later)
layer_start = -1;
layer_end = -1;
return true;
return 0;
}
if (n_embd != (int) hparams.n_embd) {
LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__);
return false;
return 1;
}
if (tensors.empty()) {
if (!init(model)) {
return false;
return 1;
}
}
@ -129,12 +130,12 @@ bool llama_adapter_cvec::apply(
}
}
return true;
return 0;
}
// lora
llama_adapter_lora_weight * llama_adapter_lora::get_weight(ggml_tensor * w) {
llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor * w) {
const std::string name(w->name);
const auto pos = ab_map.find(name);
@ -145,11 +146,11 @@ llama_adapter_lora_weight * llama_adapter_lora::get_weight(ggml_tensor * w) {
return nullptr;
}
static void llama_adapter_lora_init_impl(llama_model & model, const char * path_lora, llama_adapter_lora & adapter) {
static void llama_adapter_lora_init_impl(struct llama_model & model, const char * path_lora, struct llama_adapter_lora & adapter) {
LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);
ggml_context * ctx_init;
gguf_init_params meta_gguf_params = {
struct gguf_init_params meta_gguf_params = {
/* .no_alloc = */ true,
/* .ctx = */ &ctx_init,
};
@ -200,7 +201,7 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
// add a new context
ggml_init_params params = {
struct ggml_init_params params = {
/*.mem_size =*/ n_tensors*ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
@ -247,26 +248,6 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
}
}
// get extra buffer types of the CPU
// TODO: a more general solution for non-CPU extra buft should be imlpemented in the future
// ref: https://github.com/ggml-org/llama.cpp/pull/12593#pullrequestreview-2718659948
std::vector<ggml_backend_buffer_type_t> buft_extra;
{
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
if (ggml_backend_dev_get_extra_bufts_fn) {
ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev);
while (extra_bufts && *extra_bufts) {
buft_extra.emplace_back(*extra_bufts);
++extra_bufts;
}
}
}
// add tensors
for (auto & it : ab_map) {
const std::string & name = it.first;
@ -283,23 +264,7 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)");
}
auto * buft = ggml_backend_buffer_get_type(model_tensor->buffer);
// do not load loras to extra buffer types (i.e. bufts for repacking) -> use the CPU in that case
for (auto & ex : buft_extra) {
if (ex == buft) {
LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
buft = ggml_backend_dev_buffer_type(cpu_dev);
break;
}
}
LLAMA_LOG_DEBUG("%s: lora for '%s' -> '%s'\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
ggml_context * dev_ctx = ctx_for_buft(buft);
struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
// validate tensor shape
if (is_token_embd) {
// expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd()
@ -316,8 +281,8 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
}
// save tensor to adapter
ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
ggml_set_name(tensor_a, w.a->name);
ggml_set_name(tensor_b, w.b->name);
adapter.ab_map[name] = llama_adapter_lora_weight(tensor_a, tensor_b);
@ -343,7 +308,7 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
{
llama_file gguf_file(path_lora, "rb");
std::vector<uint8_t> read_buf;
auto set_tensor = [&](ggml_tensor * orig, ggml_tensor * dev) {
auto set_tensor = [&](struct ggml_tensor * orig, struct ggml_tensor * dev) {
size_t offs = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), gguf_find_tensor(ctx_gguf.get(), orig->name));
size_t size = ggml_nbytes(orig);
read_buf.resize(size);
@ -362,8 +327,8 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
}
llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * path_lora) {
llama_adapter_lora * adapter = new llama_adapter_lora();
struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model, const char * path_lora) {
struct llama_adapter_lora * adapter = new llama_adapter_lora();
try {
llama_adapter_lora_init_impl(*model, path_lora, *adapter);
@ -377,6 +342,6 @@ llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * p
return nullptr;
}
void llama_adapter_lora_free(llama_adapter_lora * adapter) {
void llama_adapter_lora_free(struct llama_adapter_lora * adapter) {
delete adapter;
}

View File

@ -15,11 +15,11 @@
//
struct llama_adapter_cvec {
ggml_tensor * tensor_for(int il) const;
struct ggml_tensor * tensor_for(int il) const;
ggml_tensor * apply_to(ggml_context * ctx, ggml_tensor * cur, int il) const;
struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const;
bool apply(
int32_t apply(
const llama_model & model,
const float * data,
size_t len,
@ -36,7 +36,7 @@ private:
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
std::vector<ggml_tensor *> tensors; // per layer
std::vector<struct ggml_tensor *> tensors; // per layer
};
//
@ -44,8 +44,8 @@ private:
//
struct llama_adapter_lora_weight {
ggml_tensor * a = nullptr;
ggml_tensor * b = nullptr;
struct ggml_tensor * a = nullptr;
struct ggml_tensor * b = nullptr;
// get actual scale based on rank and alpha
float get_scale(float alpha, float adapter_scale) const {
@ -55,12 +55,12 @@ struct llama_adapter_lora_weight {
}
llama_adapter_lora_weight() = default;
llama_adapter_lora_weight(ggml_tensor * a, ggml_tensor * b) : a(a), b(b) {}
llama_adapter_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {}
};
struct llama_adapter_lora {
// map tensor name to lora_a_b
std::unordered_map<std::string, llama_adapter_lora_weight> ab_map;
std::unordered_map<std::string, struct llama_adapter_lora_weight> ab_map;
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
@ -70,7 +70,5 @@ struct llama_adapter_lora {
llama_adapter_lora() = default;
~llama_adapter_lora() = default;
llama_adapter_lora_weight * get_weight(ggml_tensor * w);
llama_adapter_lora_weight * get_weight(struct ggml_tensor * w);
};
using llama_adapter_loras = std::unordered_map<llama_adapter_lora *, float>;

View File

@ -6,7 +6,6 @@
static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_LLAMA, "llama" },
{ LLM_ARCH_LLAMA4, "llama4" },
{ LLM_ARCH_DECI, "deci" },
{ LLM_ARCH_FALCON, "falcon" },
{ LLM_ARCH_GROK, "grok" },
@ -19,7 +18,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_REFACT, "refact" },
{ LLM_ARCH_BERT, "bert" },
{ LLM_ARCH_NOMIC_BERT, "nomic-bert" },
{ LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
{ LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
{ LLM_ARCH_BLOOM, "bloom" },
{ LLM_ARCH_STABLELM, "stablelm" },
@ -27,8 +25,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_QWEN2, "qwen2" },
{ LLM_ARCH_QWEN2MOE, "qwen2moe" },
{ LLM_ARCH_QWEN2VL, "qwen2vl" },
{ LLM_ARCH_QWEN3, "qwen3" },
{ LLM_ARCH_QWEN3MOE, "qwen3moe" },
{ LLM_ARCH_PHI2, "phi2" },
{ LLM_ARCH_PHI3, "phi3" },
{ LLM_ARCH_PHIMOE, "phimoe" },
@ -40,7 +36,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_MINICPM3, "minicpm3" },
{ LLM_ARCH_GEMMA, "gemma" },
{ LLM_ARCH_GEMMA2, "gemma2" },
{ LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_XVERSE, "xverse" },
@ -55,7 +50,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_DEEPSEEK, "deepseek" },
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
{ LLM_ARCH_CHATGLM, "chatglm" },
{ LLM_ARCH_GLM4, "glm4" },
{ LLM_ARCH_BITNET, "bitnet" },
{ LLM_ARCH_T5, "t5" },
{ LLM_ARCH_T5ENCODER, "t5encoder" },
@ -64,14 +58,10 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_EXAONE, "exaone" },
{ LLM_ARCH_RWKV6, "rwkv6" },
{ LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
{ LLM_ARCH_RWKV7, "rwkv7" },
{ LLM_ARCH_ARWKV7, "arwkv7" },
{ LLM_ARCH_GRANITE, "granite" },
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
{ LLM_ARCH_CHAMELEON, "chameleon" },
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
{ LLM_ARCH_PLM, "plm" },
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@ -80,7 +70,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" },
{ LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" },
{ LLM_KV_GENERAL_ALIGNMENT, "general.alignment" },
{ LLM_KV_GENERAL_FILE_TYPE, "general.file_type" },
{ LLM_KV_GENERAL_NAME, "general.name" },
{ LLM_KV_GENERAL_AUTHOR, "general.author" },
{ LLM_KV_GENERAL_VERSION, "general.version" },
@ -107,7 +96,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
{ LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
{ LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
{ LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" },
{ LLM_KV_POOLING_TYPE, "%s.pooling_type" },
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
@ -120,30 +108,23 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" },
{ LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
{ LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" },
{ LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" },
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
{ LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" },
{ LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" },
{ LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" },
{ LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" },
{ LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
{ LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
{ LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" },
{ LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" },
{ LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
{ LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
{ LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
{ LLM_KV_ATTENTION_DECAY_LORA_RANK, "%s.attention.decay_lora_rank" },
{ LLM_KV_ATTENTION_ICLR_LORA_RANK, "%s.attention.iclr_lora_rank" },
{ LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, "%s.attention.value_residual_mix_lora_rank" },
{ LLM_KV_ATTENTION_GATE_LORA_RANK, "%s.attention.gate_lora_rank" },
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
{ LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" },
{ LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" },
{ LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" },
{ LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" },
{ LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
{ LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
{ LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" },
{ LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" },
{ LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
{ LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
{ LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
@ -242,35 +223,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_LLAMA4,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
LLM_ARCH_DECI,
{
@ -474,24 +426,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_NOMIC_BERT_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_TOKEN_TYPES, "token_types" },
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_JINA_BERT_V2,
{
@ -620,45 +554,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
LLM_ARCH_QWEN3,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_QWEN3MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_PHI2,
{
@ -871,27 +766,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_GEMMA3,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_STARCODER2,
{
@ -1125,8 +999,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
{ LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
{ LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
@ -1143,22 +1015,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
},
},
{
LLM_ARCH_PLM,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
{ LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" },
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_CHATGLM,
{
@ -1177,25 +1033,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
},
},
{
LLM_ARCH_GLM4,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_BITNET,
{
@ -1380,74 +1217,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_RWKV7,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
{ LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" },
{ LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
{ LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
{ LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" },
{ LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" },
{ LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" },
{ LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" },
{ LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" },
{ LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" },
{ LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" },
{ LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" },
{ LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" },
{ LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" },
{ LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" },
{ LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
{ LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
{ LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
{ LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
{ LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
{ LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
{ LLM_TENSOR_CHANNEL_MIX_LERP_K, "blk.%d.channel_mix_lerp_k" },
{ LLM_TENSOR_CHANNEL_MIX_KEY, "blk.%d.channel_mix_key" },
{ LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" },
},
},
{
LLM_ARCH_ARWKV7,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" },
{ LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
{ LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
{ LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" },
{ LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" },
{ LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" },
{ LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" },
{ LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" },
{ LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" },
{ LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" },
{ LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" },
{ LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" },
{ LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" },
{ LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" },
{ LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
{ LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
{ LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
{ LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
{ LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
{ LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_GRANITE,
{
@ -1527,29 +1296,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" },
},
},
{
LLM_ARCH_BAILINGMOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
LLM_ARCH_UNKNOWN,
{
@ -1587,8 +1333,23 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
@ -1615,12 +1376,6 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_A1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_A2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_V1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_V2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_G1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_G2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_DECAY_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_DECAY_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_KEY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
@ -1639,9 +1394,6 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CHANNEL_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_K_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_K_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_R_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_LERP_W, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
@ -1649,9 +1401,6 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_LERP_FUSED, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_W0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_A0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_V0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
{LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},

View File

@ -10,7 +10,6 @@
enum llm_arch {
LLM_ARCH_LLAMA,
LLM_ARCH_LLAMA4,
LLM_ARCH_DECI,
LLM_ARCH_FALCON,
LLM_ARCH_BAICHUAN,
@ -23,7 +22,6 @@ enum llm_arch {
LLM_ARCH_REFACT,
LLM_ARCH_BERT,
LLM_ARCH_NOMIC_BERT,
LLM_ARCH_NOMIC_BERT_MOE,
LLM_ARCH_JINA_BERT_V2,
LLM_ARCH_BLOOM,
LLM_ARCH_STABLELM,
@ -31,8 +29,6 @@ enum llm_arch {
LLM_ARCH_QWEN2,
LLM_ARCH_QWEN2MOE,
LLM_ARCH_QWEN2VL,
LLM_ARCH_QWEN3,
LLM_ARCH_QWEN3MOE,
LLM_ARCH_PHI2,
LLM_ARCH_PHI3,
LLM_ARCH_PHIMOE,
@ -44,7 +40,6 @@ enum llm_arch {
LLM_ARCH_MINICPM3,
LLM_ARCH_GEMMA,
LLM_ARCH_GEMMA2,
LLM_ARCH_GEMMA3,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_XVERSE,
@ -59,7 +54,6 @@ enum llm_arch {
LLM_ARCH_DEEPSEEK,
LLM_ARCH_DEEPSEEK2,
LLM_ARCH_CHATGLM,
LLM_ARCH_GLM4,
LLM_ARCH_BITNET,
LLM_ARCH_T5,
LLM_ARCH_T5ENCODER,
@ -68,14 +62,10 @@ enum llm_arch {
LLM_ARCH_EXAONE,
LLM_ARCH_RWKV6,
LLM_ARCH_RWKV6QWEN2,
LLM_ARCH_RWKV7,
LLM_ARCH_ARWKV7,
LLM_ARCH_GRANITE,
LLM_ARCH_GRANITE_MOE,
LLM_ARCH_CHAMELEON,
LLM_ARCH_WAVTOKENIZER_DEC,
LLM_ARCH_PLM,
LLM_ARCH_BAILINGMOE,
LLM_ARCH_UNKNOWN,
};
@ -84,7 +74,6 @@ enum llm_kv {
LLM_KV_GENERAL_ARCHITECTURE,
LLM_KV_GENERAL_QUANTIZATION_VERSION,
LLM_KV_GENERAL_ALIGNMENT,
LLM_KV_GENERAL_FILE_TYPE,
LLM_KV_GENERAL_NAME,
LLM_KV_GENERAL_AUTHOR,
LLM_KV_GENERAL_VERSION,
@ -111,7 +100,6 @@ enum llm_kv {
LLM_KV_EXPERT_WEIGHTS_SCALE,
LLM_KV_EXPERT_WEIGHTS_NORM,
LLM_KV_EXPERT_GATING_FUNC,
LLM_KV_MOE_EVERY_N_LAYERS,
LLM_KV_POOLING_TYPE,
LLM_KV_LOGIT_SCALE,
LLM_KV_DECODER_START_TOKEN_ID,
@ -124,7 +112,6 @@ enum llm_kv {
LLM_KV_RESIDUAL_SCALE,
LLM_KV_EMBEDDING_SCALE,
LLM_KV_TOKEN_SHIFT_COUNT,
LLM_KV_INTERLEAVE_MOE_LAYER_STEP,
LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV,
@ -139,15 +126,9 @@ enum llm_kv {
LLM_KV_ATTENTION_CAUSAL,
LLM_KV_ATTENTION_Q_LORA_RANK,
LLM_KV_ATTENTION_KV_LORA_RANK,
LLM_KV_ATTENTION_DECAY_LORA_RANK,
LLM_KV_ATTENTION_ICLR_LORA_RANK,
LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK,
LLM_KV_ATTENTION_GATE_LORA_RANK,
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
LLM_KV_ATTENTION_SLIDING_WINDOW,
LLM_KV_ATTENTION_SCALE,
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_DIMENSION_SECTIONS,
@ -261,8 +242,6 @@ enum llm_tensor {
LLM_TENSOR_ATTN_Q_NORM,
LLM_TENSOR_ATTN_K_NORM,
LLM_TENSOR_LAYER_OUT_NORM,
LLM_TENSOR_POST_ATTN_NORM,
LLM_TENSOR_POST_MLP_NORM,
LLM_TENSOR_SSM_IN,
LLM_TENSOR_SSM_CONV1D,
LLM_TENSOR_SSM_X,
@ -270,20 +249,8 @@ enum llm_tensor {
LLM_TENSOR_SSM_A,
LLM_TENSOR_SSM_D,
LLM_TENSOR_SSM_OUT,
LLM_TENSOR_TIME_MIX_W0,
LLM_TENSOR_TIME_MIX_W1,
LLM_TENSOR_TIME_MIX_W2,
LLM_TENSOR_TIME_MIX_A0,
LLM_TENSOR_TIME_MIX_A1,
LLM_TENSOR_TIME_MIX_A2,
LLM_TENSOR_TIME_MIX_V0,
LLM_TENSOR_TIME_MIX_V1,
LLM_TENSOR_TIME_MIX_V2,
LLM_TENSOR_TIME_MIX_G1,
LLM_TENSOR_TIME_MIX_G2,
LLM_TENSOR_TIME_MIX_K_K,
LLM_TENSOR_TIME_MIX_K_A,
LLM_TENSOR_TIME_MIX_R_K,
LLM_TENSOR_TIME_MIX_LERP_X,
LLM_TENSOR_TIME_MIX_LERP_W,
LLM_TENSOR_TIME_MIX_LERP_K,
@ -310,8 +277,6 @@ enum llm_tensor {
LLM_TENSOR_ATTN_Q_B,
LLM_TENSOR_ATTN_KV_A_MQA,
LLM_TENSOR_ATTN_KV_B,
LLM_TENSOR_ATTN_K_B,
LLM_TENSOR_ATTN_V_B,
LLM_TENSOR_ATTN_Q_A_NORM,
LLM_TENSOR_ATTN_KV_A_NORM,
LLM_TENSOR_ATTN_SUB_NORM,

View File

@ -42,9 +42,9 @@ struct llama_sbatch {
bool logits_all; // TODO: remove once lctx.logits_all is removed too
// sorted indices into the batch
std::vector<int64_t> ids;
std::vector<size_t> ids;
// batch indices of the output
std::vector<int64_t> out_ids;
std::vector<size_t> out_ids;
std::vector<llama_sbatch_seq> seq;
const llama_batch * batch = nullptr;

View File

@ -4,7 +4,6 @@
#include <map>
#include <sstream>
#include <algorithm>
#if __cplusplus >= 202000L
#define LU8(x) (const char*)(u8##x)
@ -50,8 +49,8 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "deepseek3", LLM_CHAT_TEMPLATE_DEEPSEEK_3 },
{ "command-r", LLM_CHAT_TEMPLATE_COMMAND_R },
{ "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 },
{ "chatglm3", LLM_CHAT_TEMPLATE_CHATGLM_3 },
{ "chatglm4", LLM_CHAT_TEMPLATE_CHATGLM_4 },
{ "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 },
{ "chatglm4", LLM_CHAT_TEMPLATE_CHATGML_4 },
{ "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE },
{ "minicpm", LLM_CHAT_TEMPLATE_MINICPM },
{ "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
@ -59,10 +58,6 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "granite", LLM_CHAT_TEMPLATE_GRANITE },
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
{ "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
{ "yandex", LLM_CHAT_TEMPLATE_YANDEX },
{ "bailing", LLM_CHAT_TEMPLATE_BAILING },
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
{ "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
};
llm_chat_template llm_chat_template_from_str(const std::string & name) {
@ -82,9 +77,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
if (tmpl_contains("<|im_start|>")) {
return tmpl_contains("<|im_sep|>")
? LLM_CHAT_TEMPLATE_PHI_4
: tmpl_contains("<end_of_utterance>")
? LLM_CHAT_TEMPLATE_SMOLVLM // SmolVLM uses <|im_start|> as BOS, but it is NOT chatml
: LLM_CHAT_TEMPLATE_CHATML;
: LLM_CHAT_TEMPLATE_CHATML;
} else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) {
if (tmpl_contains("[SYSTEM_PROMPT]")) {
return LLM_CHAT_TEMPLATE_MISTRAL_V7;
@ -122,12 +115,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
}
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) {
return LLM_CHAT_TEMPLATE_PHI_3;
} else if (tmpl_contains("[gMASK]<sop>")) {
return LLM_CHAT_TEMPLATE_CHATGLM_4;
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
return tmpl_contains("</s>") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE;
} else if (tmpl_contains("<|{{ item['role'] }}|>") && tmpl_contains("<|begin_of_image|>")) {
return LLM_CHAT_TEMPLATE_GLMEDGE;
} else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) {
return LLM_CHAT_TEMPLATE_ZEPHYR;
} else if (tmpl_contains("bos_token + message['role']")) {
@ -156,7 +145,9 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_LLAMA_3;
} else if (tmpl_contains("[gMASK]sop")) {
// chatglm3-6b
return LLM_CHAT_TEMPLATE_CHATGLM_3;
return LLM_CHAT_TEMPLATE_CHATGML_3;
} else if (tmpl_contains("[gMASK]<sop>")) {
return LLM_CHAT_TEMPLATE_CHATGML_4;
} else if (tmpl_contains(LU8("<用户>"))) {
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
return LLM_CHAT_TEMPLATE_MINICPM;
@ -176,12 +167,6 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_GIGACHAT;
} else if (tmpl_contains("<|role_start|>")) {
return LLM_CHAT_TEMPLATE_MEGREZ;
} else if (tmpl_contains(" Ассистент:")) {
return LLM_CHAT_TEMPLATE_YANDEX;
} else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("'HUMAN'")) {
return LLM_CHAT_TEMPLATE_BAILING;
} else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
return LLM_CHAT_TEMPLATE_LLAMA4;
}
return LLM_CHAT_TEMPLATE_UNKNOWN;
}
@ -437,7 +422,7 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_3) {
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_3) {
// chatglm3-6b
ss << "[gMASK]" << "sop";
for (auto message : chat) {
@ -447,7 +432,7 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<|assistant|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_4 || tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) {
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_4) {
ss << "[gMASK]" << "<sop>";
for (auto message : chat) {
std::string role(message->role);
@ -456,6 +441,14 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<|assistant|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) {
for (auto message : chat) {
std::string role(message->role);
ss << "<|" << role << "|>" << "\n" << message->content;
}
if (add_ass) {
ss << "<|assistant|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_MINICPM) {
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
for (auto message : chat) {
@ -573,66 +566,6 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<|role_start|>assistant<|role_end|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_YANDEX) {
// Yandex template ("\n\n" is defined as EOT token)
ss << "<s>";
for (size_t i = 0; i < chat.size(); i++) {
std::string role(chat[i]->role);
if (role == "user") {
ss << " Пользователь: " << chat[i]->content << "\n\n";
} else if (role == "assistant") {
ss << " Ассистент: " << chat[i]->content << "\n\n";
}
}
// Add generation prompt if needed
if (add_ass) {
ss << " Ассистент:[SEP]";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_BAILING) {
// Bailing (Ling) template
for (auto message : chat) {
std::string role(message->role);
if (role == "user") {
role = "HUMAN";
} else {
std::transform(role.begin(), role.end(), role.begin(), ::toupper);
}
ss << "<role>" << role << "</role>" << message->content;
}
if (add_ass) {
ss << "<role>ASSISTANT</role>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA4) {
// Llama 4
for (auto message : chat) {
std::string role(message->role);
ss << "<|header_start|>" << role << "<|header_end|>\n\n" << trim(message->content) << "<|eot|>";
}
if (add_ass) {
ss << "<|header_start|>assistant<|header_end|>\n\n";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_SMOLVLM) {
// SmolVLM
ss << "<|im_start|>"; // uses <|im_start|> as BOS, but the actual content is NOT chatml
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << message->content << "\n\n";
} else if (role == "user") {
ss << "User: " << message->content << "<end_of_utterance>\n";
} else {
ss << "Assistant: " << message->content << "<end_of_utterance>\n";
}
}
if (add_ass) {
ss << "Assistant:";
}
} else {
// template not supported
return -1;
@ -651,3 +584,4 @@ int32_t llama_chat_builtin_templates(const char ** output, size_t len) {
}
return (int32_t) LLM_CHAT_TEMPLATES.size();
}

View File

@ -29,8 +29,8 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_DEEPSEEK_3,
LLM_CHAT_TEMPLATE_COMMAND_R,
LLM_CHAT_TEMPLATE_LLAMA_3,
LLM_CHAT_TEMPLATE_CHATGLM_3,
LLM_CHAT_TEMPLATE_CHATGLM_4,
LLM_CHAT_TEMPLATE_CHATGML_3,
LLM_CHAT_TEMPLATE_CHATGML_4,
LLM_CHAT_TEMPLATE_GLMEDGE,
LLM_CHAT_TEMPLATE_MINICPM,
LLM_CHAT_TEMPLATE_EXAONE_3,
@ -38,10 +38,6 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_GRANITE,
LLM_CHAT_TEMPLATE_GIGACHAT,
LLM_CHAT_TEMPLATE_MEGREZ,
LLM_CHAT_TEMPLATE_YANDEX,
LLM_CHAT_TEMPLATE_BAILING,
LLM_CHAT_TEMPLATE_LLAMA4,
LLM_CHAT_TEMPLATE_SMOLVLM,
LLM_CHAT_TEMPLATE_UNKNOWN,
};

File diff suppressed because it is too large Load Diff

View File

@ -3,212 +3,66 @@
#include "llama.h"
#include "llama-batch.h"
#include "llama-cparams.h"
#include "llama-graph.h"
#include "llama-model.h"
#include "llama-kv-cache.h"
#include "llama-adapter.h"
#include "ggml-cpp.h"
#include <map>
#include <unordered_map>
#include <vector>
struct llama_model;
struct llama_kv_cache;
class llama_io_read_i;
class llama_io_write_i;
#include <set>
struct llama_context {
// init scheduler and compute buffers, reserve worst-case graphs
llama_context(
const llama_model & model,
llama_context_params params);
llama_context(const llama_model & model)
: model(model)
, t_start_us(model.t_start_us)
, t_load_us(model.t_load_us) {}
~llama_context();
const struct llama_model & model;
void synchronize();
struct llama_cparams cparams;
struct llama_sbatch sbatch; // TODO: revisit if needed
struct llama_kv_cache kv_self;
struct llama_adapter_cvec cvec;
const llama_model & get_model() const;
std::unordered_map<struct llama_adapter_lora *, float> lora;
uint32_t n_ctx() const;
uint32_t n_ctx_per_seq() const;
uint32_t n_batch() const;
uint32_t n_ubatch() const;
uint32_t n_seq_max() const;
std::vector<ggml_backend_ptr> backends;
std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
uint32_t n_threads() const;
uint32_t n_threads_batch() const;
ggml_backend_t backend_cpu = nullptr;
llama_kv_cache * get_kv_self();
const llama_kv_cache * get_kv_self() const;
ggml_threadpool_t threadpool = nullptr;
ggml_threadpool_t threadpool_batch = nullptr;
void kv_self_update();
bool has_evaluated_once = false;
enum llama_pooling_type pooling_type() const;
mutable int64_t t_start_us;
mutable int64_t t_load_us;
mutable int64_t t_p_eval_us = 0;
mutable int64_t t_eval_us = 0;
float * get_logits();
float * get_logits_ith(int32_t i);
mutable int64_t t_compute_start_us = 0;
mutable int64_t n_queued_tokens = 0;
float * get_embeddings();
float * get_embeddings_ith(int32_t i);
float * get_embeddings_seq(llama_seq_id seq_id);
mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
mutable int32_t n_eval = 0; // number of eval calls
void attach_threadpool(
ggml_threadpool_t threadpool,
ggml_threadpool_t threadpool_batch);
void detach_threadpool();
void set_n_threads(int32_t n_threads, int32_t n_threads_batch);
void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);
void set_embeddings (bool value);
void set_causal_attn(bool value);
void set_warmup(bool value);
void set_adapter_lora(
llama_adapter_lora * adapter,
float scale);
bool rm_adapter_lora(
llama_adapter_lora * adapter);
void clear_adapter_lora();
bool apply_adapter_cvec(
const float * data,
size_t len,
int32_t n_embd,
int32_t il_start,
int32_t il_end);
int encode(llama_batch & inp_batch);
int decode(llama_batch & inp_batch);
//
// state save/load
//
size_t state_get_size();
size_t state_get_data( uint8_t * dst, size_t size);
size_t state_set_data(const uint8_t * src, size_t size);
size_t state_seq_get_size(llama_seq_id seq_id);
size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size);
size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);
bool state_load_file(
const char * filepath,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out);
bool state_save_file(
const char * filepath,
const llama_token * tokens,
size_t n_token_count);
size_t state_seq_load_file(
llama_seq_id seq_id,
const char * filepath,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out);
size_t state_seq_save_file(
llama_seq_id seq_id,
const char * filepath,
const llama_token * tokens,
size_t n_token_count);
//
// perf
//
llama_perf_context_data perf_get_data() const;
void perf_reset();
private:
//
// output
//
// Make sure enough space is available for outputs.
// Returns max number of outputs for which space was reserved.
int32_t output_reserve(int32_t n_outputs);
// make the outputs have the same order they had in the user-provided batch
// TODO: maybe remove this
void output_reorder();
//
// graph
//
int32_t graph_max_nodes() const;
// zero-out inputs and create the ctx_compute for the compute graph
ggml_cgraph * graph_init();
llm_graph_result_ptr graph_build(
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype);
// returns the result of ggml_backend_sched_graph_compute_async execution
ggml_status graph_compute(
ggml_cgraph * gf,
bool batched);
llm_graph_cb graph_get_cb() const;
// used by kv_self_update()
ggml_tensor * build_rope_shift(
ggml_context * ctx0,
ggml_tensor * cur,
ggml_tensor * shift,
ggml_tensor * factors,
float freq_base,
float freq_scale) const;
llm_graph_result_ptr build_kv_self_shift(
ggml_context * ctx0,
ggml_cgraph * gf) const;
llm_graph_result_ptr build_kv_self_defrag(
ggml_context * ctx0,
ggml_cgraph * gf) const;
// TODO: read/write lora adapters and cvec
size_t state_write_data(llama_io_write_i & io);
size_t state_read_data (llama_io_read_i & io);
size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id);
size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id);
//
// members
//
const llama_model & model;
llama_cparams cparams;
llama_adapter_cvec cvec;
llama_adapter_loras loras;
llama_sbatch sbatch;
llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
std::unique_ptr<llama_kv_cache_unified> kv_self;
// TODO: remove
bool logits_all = false;
// host buffer for the model output (logits and embeddings)
ggml_backend_buffer_ptr buf_output;
// decode output (2-dimensional array: [n_outputs][n_vocab])
size_t logits_size = 0; // capacity (of floats) for logits
float * logits = nullptr;
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
size_t output_size = 0; // capacity (of tokens positions) for the output buffers
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
bool logits_all = false;
// embeddings output (2-dimensional array: [n_outputs][n_embd])
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
size_t embd_size = 0; // capacity (of floats) for embeddings
@ -218,47 +72,57 @@ private:
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
std::map<llama_seq_id, std::vector<float>> embd_seq;
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
int32_t n_outputs_max = 0; // capacity (of tokens positions) for the output buffers
// whether we are computing encoder output or decoder output
bool is_encoding = false;
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
// TODO: find a better way to accommodate mutli-dimension position encoding methods
// number of position id each token get, 1 for each token in most cases.
// when using m-rope, it will be 3 position ids per token to representing 3 dimension coordinate.
int n_pos_per_token = 1;
// output of the encoder part of the encoder-decoder models
std::vector<float> embd_enc;
std::vector<std::set<llama_seq_id>> seq_ids_enc;
// memory buffers used to evaluate the model
std::vector<uint8_t> buf_compute_meta;
ggml_backend_sched_ptr sched;
ggml_backend_t backend_cpu = nullptr;
std::vector<ggml_backend_ptr> backends;
ggml_context_ptr ctx_compute;
ggml_threadpool_t threadpool = nullptr;
ggml_threadpool_t threadpool_batch = nullptr;
ggml_abort_callback abort_callback = nullptr;
void * abort_callback_data = nullptr;
std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
// buffer types used for the compute buffer of each backend
std::vector<ggml_backend_t> backend_ptrs;
std::vector<ggml_backend_buffer_type_t> backend_buft;
// memory buffers used to evaluate the model
std::vector<uint8_t> buf_compute_meta;
// host buffer for the model output (logits and embeddings)
ggml_backend_buffer_ptr buf_output;
bool has_evaluated_once = false;
// perf
mutable int64_t t_start_us = 0;
mutable int64_t t_load_us = 0;
mutable int64_t t_p_eval_us = 0;
mutable int64_t t_eval_us = 0;
mutable int64_t t_compute_start_us = 0;
mutable int64_t n_queued_tokens = 0;
mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
mutable int32_t n_eval = 0; // number of eval calls
// input tensors
struct ggml_tensor * inp_tokens; // I32 [n_batch]
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
struct ggml_tensor * inp_pos; // I32 [n_batch]
struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch]
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
struct ggml_tensor * inp_cls; // I32 [n_batch]
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
};
// TODO: make these methods of llama_context
void llama_set_k_shift(struct llama_context & lctx);
void llama_set_s_copy(struct llama_context & lctx);
void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch);
// Make sure enough space is available for outputs.
// Returns max number of outputs for which space was reserved.
size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs);
// make the outputs have the same order they had in the user-provided batch
void llama_output_reorder(struct llama_context & ctx);
// For internal test use
// TODO: remove
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(struct llama_context * ctx);

View File

@ -29,7 +29,6 @@ struct llama_cparams {
bool offload_kqv;
bool flash_attn;
bool no_perf;
bool warmup;
enum llama_pooling_type pooling_type;

View File

@ -345,194 +345,194 @@ const char * llama_grammar_parser::parse_sequence(
size_t last_sym_start = rule.size();
const char * pos = src;
auto handle_repetitions = [&](int min_times, int max_times) {
auto handle_repetitions = [&](int min_times, int max_times) {
if (last_sym_start == rule.size()) {
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
}
// apply transformation to previous symbol (last_sym_start to end) according to
// the following rewrite rules:
// S{m,n} --> S S S (m times) S'(n-m)
// S'(x) ::= S S'(x-1) |
// (... n-m definitions of these S' rules ...)
// S'(1) ::= S |
// S{m,} --> S S S (m times) S'
// S' ::= S S' |
// S* --> S{0,}
// --> S' ::= S S' |
// S+ --> S{1,}
// --> S S'
// S' ::= S S' |
// S? --> S{0,1}
// --> S'
// S' ::= S |
llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
if (min_times == 0) {
rule.resize(last_sym_start);
} else {
// Repeat the previous elements (min_times - 1) times
for (int i = 1; i < min_times; i++) {
rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
if (last_sym_start == rule.size()) {
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
}
}
uint32_t last_rec_rule_id = 0;
auto n_opt = max_times < 0 ? 1 : max_times - min_times;
// apply transformation to previous symbol (last_sym_start to end) according to
// the following rewrite rules:
// S{m,n} --> S S S (m times) S'(n-m)
// S'(x) ::= S S'(x-1) |
// (... n-m definitions of these S' rules ...)
// S'(1) ::= S |
// S{m,} --> S S S (m times) S'
// S' ::= S S' |
// S* --> S{0,}
// --> S' ::= S S' |
// S+ --> S{1,}
// --> S S'
// S' ::= S S' |
// S? --> S{0,1}
// --> S'
// S' ::= S |
llama_grammar_rule rec_rule(prev_rule);
for (int i = 0; i < n_opt; i++) {
rec_rule.resize(prev_rule.size());
uint32_t rec_rule_id = generate_symbol_id( rule_name);
if (i > 0 || max_times < 0) {
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
}
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
add_rule( rec_rule_id, rec_rule);
last_rec_rule_id = rec_rule_id;
}
if (n_opt > 0) {
rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
}
};
while (*pos) {
if (*pos == '"') { // literal string
pos++;
last_sym_start = rule.size();
while (*pos != '"') {
if (!*pos) {
throw std::runtime_error("unexpected end of input");
llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
if (min_times == 0) {
rule.resize(last_sym_start);
} else {
// Repeat the previous elements (min_times - 1) times
for (int i = 1; i < min_times; i++) {
rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
}
auto char_pair = parse_char(pos);
pos = char_pair.second;
rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '[') { // char range(s)
pos++;
enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
if (*pos == '^') {
uint32_t last_rec_rule_id = 0;
auto n_opt = max_times < 0 ? 1 : max_times - min_times;
llama_grammar_rule rec_rule(prev_rule);
for (int i = 0; i < n_opt; i++) {
rec_rule.resize(prev_rule.size());
uint32_t rec_rule_id = generate_symbol_id( rule_name);
if (i > 0 || max_times < 0) {
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
}
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
add_rule( rec_rule_id, rec_rule);
last_rec_rule_id = rec_rule_id;
}
if (n_opt > 0) {
rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
}
};
while (*pos) {
if (*pos == '"') { // literal string
pos++;
start_type = LLAMA_GRETYPE_CHAR_NOT;
}
last_sym_start = rule.size();
while (*pos != ']') {
if (!*pos) {
throw std::runtime_error("unexpected end of input");
}
auto char_pair = parse_char(pos);
pos = char_pair.second;
enum llama_gretype type = last_sym_start < rule.size()
? LLAMA_GRETYPE_CHAR_ALT
: start_type;
rule.push_back({type, char_pair.first});
if (pos[0] == '-' && pos[1] != ']') {
if (!pos[1]) {
last_sym_start = rule.size();
while (*pos != '"') {
if (!*pos) {
throw std::runtime_error("unexpected end of input");
}
auto endchar_pair = parse_char(pos + 1);
pos = endchar_pair.second;
rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
}
}
pos = parse_space(pos + 1, is_nested);
} else if (is_word_char(*pos)) { // rule reference
const char * name_end = parse_name(pos);
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
pos = parse_space(name_end, is_nested);
last_sym_start = rule.size();
rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
} else if (*pos == '(') { // grouping
// parse nested alternates into synthesized rule
pos = parse_space(pos + 1, true);
uint32_t sub_rule_id = generate_symbol_id(rule_name);
pos = parse_alternates(pos, rule_name, sub_rule_id, true);
last_sym_start = rule.size();
// output reference to synthesized rule
rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
if (*pos != ')') {
throw std::runtime_error(std::string("expecting ')' at ") + pos);
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '.') { // any char
last_sym_start = rule.size();
rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '*') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(0, -1);
} else if (*pos == '+') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(1, -1);
} else if (*pos == '?') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(0, 1);
} else if (*pos == '{') {
pos = parse_space(pos + 1, is_nested);
if (!is_digit_char(*pos)) {
throw std::runtime_error(std::string("expecting an int at ") + pos);
}
const char * int_end = parse_int(pos);
int min_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
int max_times = -1;
if (*pos == '}') {
max_times = min_times;
pos = parse_space(pos + 1, is_nested);
} else if (*pos == ',') {
pos = parse_space(pos + 1, is_nested);
if (is_digit_char(*pos)) {
const char * int_end = parse_int(pos);
max_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
}
if (*pos != '}') {
throw std::runtime_error(std::string("expecting '}' at ") + pos);
auto char_pair = parse_char(pos);
pos = char_pair.second;
rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '[') { // char range(s)
pos++;
enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
if (*pos == '^') {
pos++;
start_type = LLAMA_GRETYPE_CHAR_NOT;
}
last_sym_start = rule.size();
while (*pos != ']') {
if (!*pos) {
throw std::runtime_error("unexpected end of input");
}
auto char_pair = parse_char(pos);
pos = char_pair.second;
enum llama_gretype type = last_sym_start < rule.size()
? LLAMA_GRETYPE_CHAR_ALT
: start_type;
rule.push_back({type, char_pair.first});
if (pos[0] == '-' && pos[1] != ']') {
if (!pos[1]) {
throw std::runtime_error("unexpected end of input");
}
auto endchar_pair = parse_char(pos + 1);
pos = endchar_pair.second;
rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
}
}
pos = parse_space(pos + 1, is_nested);
} else if (is_word_char(*pos)) { // rule reference
const char * name_end = parse_name(pos);
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
pos = parse_space(name_end, is_nested);
last_sym_start = rule.size();
rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
} else if (*pos == '(') { // grouping
// parse nested alternates into synthesized rule
pos = parse_space(pos + 1, true);
uint32_t sub_rule_id = generate_symbol_id(rule_name);
pos = parse_alternates(pos, rule_name, sub_rule_id, true);
last_sym_start = rule.size();
// output reference to synthesized rule
rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
if (*pos != ')') {
throw std::runtime_error(std::string("expecting ')' at ") + pos);
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '.') { // any char
last_sym_start = rule.size();
rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '*') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(0, -1);
} else if (*pos == '+') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(1, -1);
} else if (*pos == '?') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(0, 1);
} else if (*pos == '{') {
pos = parse_space(pos + 1, is_nested);
if (!is_digit_char(*pos)) {
throw std::runtime_error(std::string("expecting an int at ") + pos);
}
const char * int_end = parse_int(pos);
int min_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
int max_times = -1;
if (*pos == '}') {
max_times = min_times;
pos = parse_space(pos + 1, is_nested);
} else if (*pos == ',') {
pos = parse_space(pos + 1, is_nested);
if (is_digit_char(*pos)) {
const char * int_end = parse_int(pos);
max_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
}
if (*pos != '}') {
throw std::runtime_error(std::string("expecting '}' at ") + pos);
}
pos = parse_space(pos + 1, is_nested);
} else {
throw std::runtime_error(std::string("expecting ',' at ") + pos);
}
handle_repetitions(min_times, max_times);
} else {
throw std::runtime_error(std::string("expecting ',' at ") + pos);
break;
}
handle_repetitions(min_times, max_times);
} else {
break;
}
return pos;
}
return pos;
}
const char * llama_grammar_parser::parse_rule(const char * src) {
const char * name_end = parse_name(src);
const char * pos = parse_space(name_end, false);
size_t name_len = name_end - src;
uint32_t rule_id = get_symbol_id(src, name_len);
const std::string name(src, name_len);
const char * name_end = parse_name(src);
const char * pos = parse_space(name_end, false);
size_t name_len = name_end - src;
uint32_t rule_id = get_symbol_id(src, name_len);
const std::string name(src, name_len);
if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
throw std::runtime_error(std::string("expecting ::= at ") + pos);
if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
throw std::runtime_error(std::string("expecting ::= at ") + pos);
}
pos = parse_space(pos + 3, true);
pos = parse_alternates(pos, name, rule_id, false);
if (*pos == '\r') {
pos += pos[1] == '\n' ? 2 : 1;
} else if (*pos == '\n') {
pos++;
} else if (*pos) {
throw std::runtime_error(std::string("expecting newline or end at ") + pos);
}
return parse_space(pos, true);
}
pos = parse_space(pos + 3, true);
pos = parse_alternates(pos, name, rule_id, false);
if (*pos == '\r') {
pos += pos[1] == '\n' ? 2 : 1;
} else if (*pos == '\n') {
pos++;
} else if (*pos) {
throw std::runtime_error(std::string("expecting newline or end at ") + pos);
}
return parse_space(pos, true);
}
bool llama_grammar_parser::parse(const char * src) {
try {
@ -969,7 +969,7 @@ struct llama_grammar * llama_grammar_init_impl(
/* .awaiting_trigger = */ false,
/* .trigger_buffer = */ "",
/* .trigger_tokens = */ {},
/* .trigger_patterns = */ {},
/* .trigger_words = */ {},
};
}
@ -978,15 +978,19 @@ struct llama_grammar * llama_grammar_init_impl(
const char * grammar_str,
const char * grammar_root,
bool lazy,
const char ** trigger_patterns,
size_t num_trigger_patterns,
const char ** trigger_words,
size_t num_trigger_words,
const llama_token * trigger_tokens,
size_t num_trigger_tokens) {
llama_grammar_parser parser;
// if there is a grammar, parse it
// rules will be empty (default) if there are parse errors
if (!parser.parse(grammar_str) || parser.rules.empty()) {
if (!parser.parse(grammar_str)) {
return nullptr;
}
// will be empty (default) if there are parse errors
if (parser.rules.empty()) {
fprintf(stderr, "%s: failed to parse grammar\n", __func__);
return nullptr;
}
@ -1050,16 +1054,14 @@ struct llama_grammar * llama_grammar_init_impl(
} while (true);
std::vector<llama_token> vec_trigger_tokens;
std::vector<llama_grammar_trigger_pattern> vec_trigger_patterns;
std::vector<std::string> vec_trigger_words;
for (size_t i = 0; i < num_trigger_tokens; i++) {
GGML_ASSERT(trigger_tokens != nullptr);
vec_trigger_tokens.push_back(trigger_tokens[i]);
}
for (size_t i = 0; i < num_trigger_patterns; i++) {
GGML_ASSERT(trigger_patterns != nullptr);
auto & trigger = vec_trigger_patterns.emplace_back();
trigger.pattern = trigger_patterns[i];
trigger.regex = std::regex(trigger.pattern);
for (size_t i = 0; i < num_trigger_words; i++) {
GGML_ASSERT(trigger_words != nullptr);
vec_trigger_words.push_back(trigger_words[i]);
}
// Important: vec_rules has to be moved here, not copied, because stacks contains
@ -1074,7 +1076,7 @@ struct llama_grammar * llama_grammar_init_impl(
/* .awaiting_trigger = */ lazy,
/* .trigger_buffer = */ "",
std::move(vec_trigger_tokens),
std::move(vec_trigger_patterns),
std::move(vec_trigger_words),
};
}
@ -1087,7 +1089,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
}
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
auto * result = new llama_grammar {
llama_grammar * result = new llama_grammar {
grammar.vocab,
grammar.rules,
grammar.stacks,
@ -1096,7 +1098,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
grammar.awaiting_trigger,
grammar.trigger_buffer,
grammar.trigger_tokens,
grammar.trigger_patterns,
grammar.trigger_words,
};
// redirect elements in stacks to point to new rules
@ -1171,22 +1173,20 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
return;
} else {
// TODO: consider a smarter incremental substring search algorithm (store last position to search from).
grammar.trigger_buffer += piece;
std::smatch match;
for (const auto & trigger_pattern : grammar.trigger_patterns) {
if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
for (const auto & word : grammar.trigger_words) {
auto pos = grammar.trigger_buffer.find(word);
if (pos != std::string::npos) {
grammar.awaiting_trigger = false;
// get from the first match to the end of the string
auto constrained_str = grammar.trigger_buffer.substr(match.position(1));
// std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
auto constrained_str = grammar.trigger_buffer.substr(pos);
grammar.trigger_buffer.clear();
llama_grammar_accept_str(grammar, constrained_str);
LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
LLAMA_LOG_DEBUG("Grammar triggered on word `%s`", word.c_str());
return;
}
}
LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str());
LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`) (buffer: `%s`)\n", token, piece.c_str(), grammar.trigger_buffer.c_str());
return;
}
}

View File

@ -3,7 +3,6 @@
#include "llama.h"
#include <map>
#include <regex>
#include <string>
#include <vector>
@ -106,11 +105,6 @@ struct llama_grammar_parser {
void print(FILE * file);
};
struct llama_grammar_trigger_pattern {
std::string pattern;
std::regex regex;
};
struct llama_grammar {
// note: allow null vocab for testing (not great)
const llama_vocab * vocab;
@ -122,16 +116,13 @@ struct llama_grammar {
llama_partial_utf8 partial_utf8;
// lazy grammars wait for trigger words or tokens before constraining the sampling.
// we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
// we still ahve trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
// (useful e.g. for tool_choice=required)
bool lazy = false;
bool awaiting_trigger = false; // Initialized to true for lazy grammars only
std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
std::vector<llama_grammar_trigger_pattern>
trigger_patterns; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated
// string, and the grammar will be given the string from the first match group onwards.
std::vector<std::string> trigger_words;
};
//
@ -150,8 +141,8 @@ struct llama_grammar * llama_grammar_init_impl(
const char * grammar_str,
const char * grammar_root,
bool lazy,
const char ** trigger_patterns,
size_t num_trigger_patterns,
const char ** trigger_words,
size_t num_trigger_words,
const llama_token * trigger_tokens,
size_t num_trigger_tokens);

File diff suppressed because it is too large Load Diff

View File

@ -1,594 +0,0 @@
#pragma once
#include "llama-arch.h"
#include "llama-hparams.h"
#include "llama-adapter.h"
#include <cstdint>
#include <vector>
#include <memory>
#include <set>
#include <functional>
struct ggml_cgraph;
struct ggml_context;
struct ggml_tensor;
struct llama_ubatch;
struct llama_cparams;
class llama_memory_i;
class llama_kv_cache_unified;
// certain models (typically multi-modal) can produce different types of graphs
enum llm_graph_type {
LLM_GRAPH_TYPE_DEFAULT,
LLM_GRAPH_TYPE_ENCODER,
LLM_GRAPH_TYPE_DECODER,
};
enum llm_ffn_op_type {
LLM_FFN_SILU,
LLM_FFN_GELU,
LLM_FFN_RELU,
LLM_FFN_RELU_SQR,
LLM_FFN_SWIGLU,
};
enum llm_ffn_gate_type {
LLM_FFN_SEQ,
LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
};
enum llm_norm_type {
LLM_NORM,
LLM_NORM_RMS,
LLM_NORM_GROUP,
};
// TODO: tmp - need something better to pass the data from the encoder to the decoder
struct llama_cross {
// the output embeddings from the encoder as a ggml tensor
// TODO: this needs more work to be correct, for now copy the embeddings data to host memory
// ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
//ggml_tensor * t_embd = nullptr;
int64_t n_embd = 0;
int64_t n_enc = 0;
// embeddings data copied to host memory (tmp)
std::vector<float> v_embd;
// needed to construct the cross-attention mask in the decoder
std::vector<std::set<llama_seq_id>> seq_ids_enc;
};
//
// llm_graph_input
//
class llm_graph_input_i {
public:
virtual ~llm_graph_input_i() = default;
virtual void set_input(const llama_ubatch * ubatch) = 0;
};
using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
class llm_graph_input_embd : public llm_graph_input_i {
public:
llm_graph_input_embd() = default;
virtual ~llm_graph_input_embd() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * tokens = nullptr; // I32 [n_batch]
ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
};
class llm_graph_input_pos : public llm_graph_input_i {
public:
llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
virtual ~llm_graph_input_pos() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * pos = nullptr; // I32 [n_batch]
const int64_t n_pos_per_embd = 1;
};
// temperature tuning, used by llama4
class llm_graph_input_attn_temp : public llm_graph_input_i {
public:
llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
: n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
virtual ~llm_graph_input_attn_temp() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
const uint32_t n_attn_temp_floor_scale;
const float f_attn_temp_scale;
};
class llm_graph_input_pos_bucket : public llm_graph_input_i {
public:
llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
virtual ~llm_graph_input_pos_bucket() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]
const llama_hparams & hparams;
};
class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
public:
llm_graph_input_pos_bucket_kv(
const llama_hparams & hparams,
const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {}
virtual ~llm_graph_input_pos_bucket_kv() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
const llama_hparams & hparams;
const llama_kv_cache_unified * kv_self;
};
class llm_graph_input_out_ids : public llm_graph_input_i {
public:
llm_graph_input_out_ids(
const llama_hparams & hparams,
const llama_cparams & cparams,
int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
virtual ~llm_graph_input_out_ids() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * out_ids; // I32 [n_outputs]
const llama_hparams & hparams;
const llama_cparams & cparams;
const int32_t n_outputs;
};
class llm_graph_input_mean : public llm_graph_input_i {
public:
llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {}
virtual ~llm_graph_input_mean() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * mean; // F32 [n_batch, n_batch]
const llama_cparams & cparams;
};
class llm_graph_input_cls : public llm_graph_input_i {
public:
llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
virtual ~llm_graph_input_cls() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * cls; // I32 [n_batch]
const llama_cparams & cparams;
};
class llm_graph_input_s_copy : public llm_graph_input_i {
public:
llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
virtual ~llm_graph_input_s_copy() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * s_copy; // I32 [kv_size]
const llama_kv_cache_unified * kv_self;
};
class llm_graph_input_s_mask : public llm_graph_input_i {
public:
llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
virtual ~llm_graph_input_s_mask() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * s_mask; // F32 [1, n_kv]
const llama_kv_cache_unified * kv_self;
};
class llm_graph_input_cross_embd : public llm_graph_input_i {
public:
llm_graph_input_cross_embd(
const llama_cross * cross) : cross(cross) {}
virtual ~llm_graph_input_cross_embd() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]
const llama_cross * cross;
};
class llm_graph_input_attn_no_cache : public llm_graph_input_i {
public:
llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
hparams(hparams),
cparams(cparams) {
}
~llm_graph_input_attn_no_cache() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch]
ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch]
const llama_hparams & hparams;
const llama_cparams & cparams;
};
class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
public:
llm_graph_input_attn_kv_unified(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_unified * kv_self) :
hparams(hparams),
cparams(cparams),
kv_self(kv_self) {
}
~llm_graph_input_attn_kv_unified() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_kv_cache_unified * kv_self;
};
class llm_graph_input_attn_cross : public llm_graph_input_i {
public:
llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
~llm_graph_input_attn_cross() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch]
ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
const llama_cross * cross = nullptr;
};
//
// llm_graph_result
//
// these objects deliver the result from the graph build process back to the llama_context
// note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
// specific data, by calling the set_inputs() method
// along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
// these are used by the llama_context to extact the relevant data, based on the compute parameters
class llm_graph_result_i {
public:
virtual ~llm_graph_result_i() = default;
virtual ggml_tensor * get_logits() = 0;
virtual ggml_tensor * get_embd() = 0;
virtual ggml_tensor * get_embd_pooled() = 0;
virtual void set_inputs(const llama_ubatch * ubatch) = 0;
};
using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
class llm_graph_result : public llm_graph_result_i {
public:
virtual ~llm_graph_result() = default;
ggml_tensor * get_logits() override { return t_logits; }
ggml_tensor * get_embd() override { return t_embd; }
ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
void set_inputs(const llama_ubatch * ubatch) override {
for (auto & input : inputs) {
input->set_input(ubatch);
}
}
llm_graph_input_i * add_input(llm_graph_input_ptr input) {
inputs.emplace_back(std::move(input));
return inputs.back().get();
}
// important graph nodes
ggml_tensor * t_logits = nullptr;
ggml_tensor * t_embd = nullptr;
ggml_tensor * t_embd_pooled = nullptr;
std::vector<llm_graph_input_ptr> inputs;
};
//
// llm_graph_context
//
// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
struct llm_graph_params {
ggml_context * ctx;
const llm_arch arch;
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_ubatch & ubatch;
ggml_backend_sched * sched;
ggml_backend * backend_cpu;
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
const llama_memory_i * memory;
const llama_cross * cross;
int32_t n_outputs;
const llm_graph_cb & cb;
};
struct llm_graph_context {
const llm_arch arch;
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_ubatch & ubatch;
const int64_t n_embd;
const int64_t n_layer;
const int64_t n_rot;
const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
const int64_t n_ctx_per_seq;
const int64_t n_head;
const int64_t n_head_kv;
const int64_t n_embd_head_k;
const int64_t n_embd_k_gqa;
const int64_t n_embd_head_v;
const int64_t n_embd_v_gqa;
const int64_t n_expert;
const int64_t n_expert_used;
const float freq_base;
const float freq_scale;
const float ext_factor;
const float attn_factor;
const float beta_fast;
const float beta_slow;
const float norm_eps;
const float norm_rms_eps;
const int32_t n_tokens;
const int32_t n_outputs;
const int32_t n_ctx_orig; // yarn
const enum llama_pooling_type pooling_type;
const enum llama_rope_type rope_type;
ggml_context * ctx0 = nullptr;
ggml_backend_sched * sched;
ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
const llama_memory_i * memory;
const llama_cross * cross;
const llm_graph_cb & cb_func;
std::unique_ptr<llm_graph_result> res;
llm_graph_context(const llm_graph_params & params);
int64_t n_pos_per_embd() const;
void cb(ggml_tensor * cur, const char * name, int il) const;
//
// common
//
ggml_tensor * build_cvec(
ggml_tensor * cur,
int il) const;
// do mat_mul, while optionally apply lora
ggml_tensor * build_lora_mm(
ggml_tensor * w,
ggml_tensor * cur) const;
// do mat_mul_id, while optionally apply lora
ggml_tensor * build_lora_mm_id(
ggml_tensor * w, // ggml_tensor * as
ggml_tensor * cur, // ggml_tensor * b
ggml_tensor * ids) const;
ggml_tensor * build_norm(
ggml_tensor * cur,
ggml_tensor * mw,
ggml_tensor * mb,
llm_norm_type type,
int il) const;
ggml_tensor * build_ffn(
ggml_tensor * cur,
ggml_tensor * up,
ggml_tensor * up_b,
ggml_tensor * up_s,
ggml_tensor * gate,
ggml_tensor * gate_b,
ggml_tensor * gate_s,
ggml_tensor * down,
ggml_tensor * down_b,
ggml_tensor * down_s,
ggml_tensor * act_scales,
llm_ffn_op_type type_op,
llm_ffn_gate_type type_gate,
int il) const;
ggml_tensor * build_moe_ffn(
ggml_tensor * cur,
ggml_tensor * gate_inp,
ggml_tensor * up_exps,
ggml_tensor * gate_exps,
ggml_tensor * down_exps,
ggml_tensor * exp_probs_b,
int64_t n_expert,
int64_t n_expert_used,
llm_ffn_op_type type_op,
bool norm_w,
bool scale_w,
float w_scale,
llama_expert_gating_func_type gating_op,
int il) const;
//
// inputs
//
ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
ggml_tensor * build_inp_pos() const;
ggml_tensor * build_inp_attn_scale() const;
ggml_tensor * build_inp_out_ids() const;
ggml_tensor * build_inp_mean() const;
ggml_tensor * build_inp_cls() const;
ggml_tensor * build_inp_s_copy() const;
ggml_tensor * build_inp_s_mask() const;
ggml_tensor * build_inp_cross_embd() const;
ggml_tensor * build_inp_pos_bucket_enc() const;
ggml_tensor * build_inp_pos_bucket_dec() const;
ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
//
// attention
//
ggml_tensor * build_attn_mha(
ggml_cgraph * gf,
ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
ggml_tensor * kq_b,
ggml_tensor * kq_mask,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
bool v_trans,
float kq_scale) const;
llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
ggml_tensor * build_attn(
llm_graph_input_attn_no_cache * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;
llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
ggml_tensor * build_attn(
llm_graph_input_attn_kv_unified * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;
llm_graph_input_attn_cross * build_attn_inp_cross() const;
ggml_tensor * build_attn(
llm_graph_input_attn_cross * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;
//
// recurrent
//
ggml_tensor * build_copy_mask_state(
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
int32_t n_state,
int32_t n_seqs) const;
ggml_tensor * build_rwkv_token_shift_load(
ggml_cgraph * gf,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const;
ggml_tensor * build_rwkv_token_shift_store(
ggml_tensor * token_shift,
const llama_ubatch & ubatch,
int il) const;
//
// pooling
//
void build_pooling(
ggml_cgraph * gf,
ggml_tensor * cls,
ggml_tensor * cls_b,
ggml_tensor * cls_out,
ggml_tensor * cls_out_b) const;
};

View File

@ -69,11 +69,3 @@ uint32_t llama_hparams::n_embd_v_s() const {
// corresponds to Mamba's ssm_states size
return ssm_d_state * ssm_d_inner;
}
bool llama_hparams::is_swa(uint32_t il) const {
if (il < n_layer) {
return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
}
GGML_ABORT("fatal error");
}

View File

@ -36,17 +36,12 @@ struct llama_hparams {
uint32_t n_layer;
uint32_t n_rot;
uint32_t n_swa = 0; // sliding window attention (SWA)
uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
uint32_t n_expert = 0;
uint32_t n_expert_used = 0;
uint32_t n_rel_attn_bkts = 0;
// note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
uint32_t n_embd_head_k_mla = 0;
uint32_t n_embd_head_v_mla = 0;
// for WavTokenizer
struct llama_hparams_posnet posnet;
struct llama_hparams_convnext convnext;
@ -66,7 +61,6 @@ struct llama_hparams {
float expert_weights_scale = 0.0;
bool expert_weights_norm = false;
uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
uint32_t moe_every_n_layers = 0;
float f_norm_eps;
float f_norm_rms_eps;
@ -81,16 +75,10 @@ struct llama_hparams {
uint32_t time_decay_extra_dim = 0;
uint32_t wkv_head_size = 0;
uint32_t token_shift_count = 2;
uint32_t n_lora_decay = 0;
uint32_t n_lora_iclr = 0;
uint32_t n_lora_value_res_mix = 0;
uint32_t n_lora_gate = 0;
float rope_attn_factor = 1.0f;
float rope_freq_base_train;
float rope_freq_base_train_swa;
float rope_freq_scale_train;
float rope_freq_scale_train_swa;
uint32_t n_ctx_orig_yarn;
float rope_yarn_log_mul;
@ -117,14 +105,6 @@ struct llama_hparams {
bool use_alibi = false;
bool attn_soft_cap = false;
uint32_t n_moe_layer_step = 0;
bool use_kq_norm = true;
uint32_t n_attn_chunk = 0;
// values below seems to be fixed on llama4
uint32_t n_no_rope_layer_step = 4;
uint32_t n_attn_temp_floor_scale = 8192;
float f_attn_temp_scale = 0.1;
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
@ -153,8 +133,6 @@ struct llama_hparams {
// dimension of the recurrent state embeddings
uint32_t n_embd_v_s() const;
bool is_swa(uint32_t il) const;
};
static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");

Some files were not shown because too many files have changed in this diff Show More