Compare commits

..

1 Commits

Author SHA1 Message Date
05ce7476ae ggml-ci: update input env variables to GG_BUILD_ 2025-03-14 03:14:44 -05:00
361 changed files with 40394 additions and 60877 deletions

View File

@ -13,6 +13,8 @@ WORKDIR /app
ARG CUDA_DOCKER_ARCH=all ARG CUDA_DOCKER_ARCH=all
# Set nvcc architecture # Set nvcc architecture
ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH} ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
# Enable cuBLAS
ENV GGML_CUDA=1
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y build-essential libsdl2-dev wget cmake git \ apt-get install -y build-essential libsdl2-dev wget cmake git \
@ -23,8 +25,7 @@ ENV CUDA_MAIN_VERSION=12.3
ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH
COPY .. . COPY .. .
# Enable cuBLAS RUN make base.en
RUN make base.en CMAKE_ARGS="-DGGML_CUDA=1"
FROM ${BASE_CUDA_RUN_CONTAINER} AS runtime FROM ${BASE_CUDA_RUN_CONTAINER} AS runtime
ENV CUDA_MAIN_VERSION=12.3 ENV CUDA_MAIN_VERSION=12.3
@ -36,5 +37,4 @@ RUN apt-get update && \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
COPY --from=build /app /app COPY --from=build /app /app
ENV PATH=/app/build/bin:$PATH
ENTRYPOINT [ "bash", "-c" ] ENTRYPOINT [ "bash", "-c" ]

View File

@ -1,29 +0,0 @@
ARG UBUNTU_VERSION=22.04
# This needs to generally match the container host's environment.
ARG MUSA_VERSION=rc3.1.1
# Target the MUSA build image
ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
# Target the MUSA runtime image
ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}
FROM ${BASE_MUSA_DEV_CONTAINER} AS build
WORKDIR /app
RUN apt-get update && \
apt-get install -y build-essential libsdl2-dev wget cmake git \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
COPY .. .
# Enable muBLAS
RUN make base.en CMAKE_ARGS="-DGGML_MUSA=1"
FROM ${BASE_MUSA_RUN_CONTAINER} AS runtime
WORKDIR /app
RUN apt-get update && \
apt-get install -y curl ffmpeg wget cmake git \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
COPY --from=build /app /app
ENV PATH=/app/build/bin:$PATH
ENTRYPOINT [ "bash", "-c" ]

View File

@ -16,5 +16,4 @@ RUN apt-get update && \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
COPY --from=build /app /app COPY --from=build /app /app
ENV PATH=/app/build/bin:$PATH
ENTRYPOINT [ "bash", "-c" ] ENTRYPOINT [ "bash", "-c" ]

View File

@ -1,3 +0,0 @@
build*/
.github/
.devops/

View File

@ -1,11 +1,55 @@
name: Bindings Tests (Ruby) name: Bindings Tests (Ruby)
on: on:
push: push:
branches: paths:
- master - bindings/ruby/**
- src/**/*.c
- src/**/*.cpp
- src/**/*.h
- src/**/*.m
- src/**/*.metal
- include/**/*.c
- include/**/*.cpp
- include/**/*.h
- include/**/*.m
- include/**/*.metal
- ggml/**/*.c
- ggml/**/*.cpp
- ggml/**/*.h
- ggml/**/*.m
- ggml/**/*.metal
- scripts/get-flags.mk
- examples/common.h
- examples/common.cpp
- examples/common-whisper.h
- examples/common-whisper.cpp
- examples/stb_vorbis.c
- examples/miniaudio.h
pull_request: pull_request:
types: [opened, synchronize, reopened] paths:
- bindings/ruby/**
- src/**/*.c
- src/**/*.cpp
- src/**/*.h
- src/**/*.m
- src/**/*.metal
- include/**/*.c
- include/**/*.cpp
- include/**/*.h
- include/**/*.m
- include/**/*.metal
- ggml/**/*.c
- ggml/**/*.cpp
- ggml/**/*.h
- ggml/**/*.m
- ggml/**/*.metal
- scripts/get-flags.mk
- examples/common.h
- examples/common.cpp
- examples/common-whisper.h
- examples/common-whisper.cpp
- examples/stb_vorbis.c
- examples/miniaudio.h
jobs: jobs:
ubuntu-22: ubuntu-22:
@ -16,6 +60,6 @@ jobs:
steps: steps:
- uses: ruby/setup-ruby@v1 - uses: ruby/setup-ruby@v1
with: with:
ruby-version: '3.2' ruby-version: '3.1'
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- run: rake test - run: rake test

View File

@ -6,81 +6,17 @@ on:
- master - master
pull_request: pull_request:
types: [opened, synchronize, reopened] types: [opened, synchronize, reopened]
workflow_dispatch:
inputs:
create_release:
description: 'Create new release'
required: true
type: boolean
pre_release_tag:
description: 'Pre-release tag name'
required: false
type: string
run_type:
description: 'Workflow type to run'
required: true
type: choice
options:
- full-ci
- release-only
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
cancel-in-progress: true cancel-in-progress: true
permissions:
contents: write # for creating release
env: env:
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
ubuntu_image: "ubuntu:22.04" ubuntu_image: "ubuntu:22.04"
VCPKG_BINARY_SOURCES: "clear;x-gha,readwrite" VCPKG_BINARY_SOURCES: "clear;x-gha,readwrite"
jobs: jobs:
determine-tag:
runs-on: ubuntu-latest
outputs:
tag_name: ${{ steps.tag.outputs.name }}
steps:
- name: Checkout with full history
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Determine tag name
id: tag
shell: bash
run: |
BUILD_NUMBER=$(git rev-list --count HEAD)
SHORT_HASH=$(git rev-parse --short=7 HEAD)
CUSTOM_TAG="${{ github.event.inputs.pre_release_tag }}"
echo "Raw values:"
echo "BUILD_NUMBER: $BUILD_NUMBER"
echo "SHORT_HASH: $SHORT_HASH"
echo "BRANCH_NAME: ${{ env.BRANCH_NAME }}"
echo "CUSTOM_TAG: $CUSTOM_TAG"
# Use custom tag if provided
if [[ -n "$CUSTOM_TAG" ]]; then
echo "Using custom tag"
TAG_NAME="${CUSTOM_TAG}"
elif [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then
echo "Using master branch format"
TAG_NAME="b${BUILD_NUMBER}"
else
echo "Using non-master branch format"
SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-')
TAG_NAME="${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}"
fi
echo "Final tag name: $TAG_NAME"
echo "name=$TAG_NAME" >> $GITHUB_OUTPUT
ubuntu-22: ubuntu-22:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
strategy: strategy:
@ -107,8 +43,6 @@ jobs:
cmake --build build --config Release -j $(nproc)' cmake --build build --config Release -j $(nproc)'
ubuntu-22-arm64: ubuntu-22-arm64:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
strategy: strategy:
@ -135,8 +69,6 @@ jobs:
cmake --build build --config Release -j $(nproc)' cmake --build build --config Release -j $(nproc)'
ubuntu-22-arm-v7: ubuntu-22-arm-v7:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
strategy: strategy:
@ -163,8 +95,6 @@ jobs:
cmake --build build --config Release -j $(nproc)' cmake --build build --config Release -j $(nproc)'
macOS-latest: macOS-latest:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: macOS-latest runs-on: macOS-latest
strategy: strategy:
@ -199,28 +129,31 @@ jobs:
-DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" -DCMAKE_OSX_ARCHITECTURES="arm64;x86_64"
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)
- name: xcodebuild for swift package
id: xcodebuild
run: |
./build-xcframework.sh
# freeBSD-latest: # freeBSD-latest:
# runs-on: macos-13 # runs-on: macos-12
# #
# steps: # steps:
# - name: Clone # - name: Clone
# uses: actions/checkout@v4 # uses: actions/checkout@v4
# #
# - name: Build # - name: Build
# uses: cross-platform-actions/action@v0.27.0 # uses: cross-platform-actions/action@v0.24.0
# with: # with:
# operating_system: freebsd # operating_system: freebsd
# version: '14.2' # version: '13.3'
# run: | # run: |
# sudo pkg update # sudo pkg update
# sudo pkg install -y gmake sdl2 cmake git # sudo pkg install -y gmake sdl2 cmake
# cmake -B build # cmake -B build
# cmake --build build --config Release # cmake --build build --config Release
ubuntu-22-gcc: ubuntu-22-gcc:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
strategy: strategy:
@ -249,8 +182,6 @@ jobs:
ctest -L gh --output-on-failure' ctest -L gh --output-on-failure'
ubuntu-22-gcc-arm64: ubuntu-22-gcc-arm64:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
strategy: strategy:
@ -279,8 +210,6 @@ jobs:
ctest -L gh --output-on-failure' ctest -L gh --output-on-failure'
ubuntu-22-gcc-arm-v7: ubuntu-22-gcc-arm-v7:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
strategy: strategy:
@ -309,8 +238,6 @@ jobs:
ctest -L gh --output-on-failure' ctest -L gh --output-on-failure'
ubuntu-22-clang: ubuntu-22-clang:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
strategy: strategy:
@ -342,8 +269,6 @@ jobs:
ctest -L gh --output-on-failure' ctest -L gh --output-on-failure'
ubuntu-22-gcc-sanitized: ubuntu-22-gcc-sanitized:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
strategy: strategy:
@ -367,15 +292,11 @@ jobs:
set -e set -e
apt update apt update
apt install -y build-essential cmake git apt install -y build-essential cmake git
cmake . -DCMAKE_BUILD_TYPE=Debug \ cmake . -DCMAKE_BUILD_TYPE=Debug -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON
-DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON \
-DGGML_OPENMP=OFF
make make
ctest -L gh --output-on-failure' ctest -L gh --output-on-failure'
ubuntu-22-cmake-sycl: ubuntu-22-cmake-sycl:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
strategy: strategy:
@ -426,8 +347,6 @@ jobs:
cmake --build . --config Release -j $(nproc) cmake --build . --config Release -j $(nproc)
ubuntu-22-cmake-sycl-fp16: ubuntu-22-cmake-sycl-fp16:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
strategy: strategy:
@ -478,8 +397,6 @@ jobs:
cmake --build . --config Release -j $(nproc) cmake --build . --config Release -j $(nproc)
windows-msys2: windows-msys2:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: windows-latest runs-on: windows-latest
strategy: strategy:
@ -524,8 +441,6 @@ jobs:
cmake --build build --config ${{ matrix.build }} -j $(nproc) cmake --build build --config ${{ matrix.build }} -j $(nproc)
windows: windows:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: windows-latest runs-on: windows-latest
strategy: strategy:
@ -561,7 +476,6 @@ jobs:
run: > run: >
cmake -S . -B ./build -A ${{ matrix.arch }} cmake -S . -B ./build -A ${{ matrix.arch }}
-DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_BUILD_TYPE=${{ matrix.build }}
-DBUILD_SHARED_LIBS=ON
-DWHISPER_SDL2=${{ matrix.sdl2 }} -DWHISPER_SDL2=${{ matrix.sdl2 }}
- name: Build - name: Build
@ -573,37 +487,12 @@ jobs:
if: matrix.sdl2 == 'ON' if: matrix.sdl2 == 'ON'
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }} run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
- name: Upload SDL2.dll - name: Upload dll
if: matrix.sdl2 == 'ON'
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: ${{ matrix.s2arc }}_SDL2.dll name: ${{ matrix.jnaPath }}_whisper.dll
path: build/bin/${{ matrix.build }}/SDL2.dll
- name: Upload whisper dll
uses: actions/upload-artifact@v4
with:
name: whisper_${{ matrix.arch }}.dll
path: build/bin/${{ matrix.build }}/whisper.dll path: build/bin/${{ matrix.build }}/whisper.dll
- name: Upload ggml dll
uses: actions/upload-artifact@v4
with:
name: ggml_${{ matrix.arch }}.dll
path: build/bin/${{ matrix.build }}/ggml.dll
- name: Upload ggml base dll
uses: actions/upload-artifact@v4
with:
name: ggml_base_${{ matrix.arch }}.dll
path: build/bin/${{ matrix.build }}/ggml-base.dll
- name: Upload ggml cpu dll
uses: actions/upload-artifact@v4
with:
name: ggml_cpu_${{ matrix.arch }}.dll
path: build/bin/${{ matrix.build }}/ggml-cpu.dll
- name: Upload binaries - name: Upload binaries
if: matrix.sdl2 == 'ON' if: matrix.sdl2 == 'ON'
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
@ -612,8 +501,6 @@ jobs:
path: build/bin/${{ matrix.build }} path: build/bin/${{ matrix.build }}
windows-blas: windows-blas:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: windows-latest runs-on: windows-latest
strategy: strategy:
@ -687,8 +574,6 @@ jobs:
path: build/bin/${{ matrix.build }} path: build/bin/${{ matrix.build }}
windows-cublas: windows-cublas:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: windows-2019 runs-on: windows-2019
strategy: strategy:
matrix: matrix:
@ -705,134 +590,15 @@ jobs:
- name: Clone repository - name: Clone repository
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Install Ninja
id: install_ninja
run: |
choco install ninja
- name: Install ccache
uses: hendrikmuhs/ccache-action@v1.2.16
with:
key: ${{ github.job }}-${{ matrix.cuda-toolkit }}-${{ matrix.build }}
variant: sccache
evict-old-files: 5d
- name: Install Cuda Toolkit 11.8.0
if: ${{ matrix.cuda-toolkit == '11.8.0' }}
run: |
$CUDA_VERSION = ${{ matrix.cuda-toolkit }}
$CUDA_TOOLKIT_DIR = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$CUDA_VERSION"
$CUDA_DOWNLOAD = "https://developer.download.nvidia.com/compute/cuda/redist"
# Components versions
$CUDART_VER = "11.8.89"
$NVCC_VER = "11.8.89"
$NVRTC_VER = "11.8.89"
$CUBLAS_VER = "11.8.1.74"
$NVTX_VER = "11.8.86"
$VS_VER = "11.8.86"
$NVPROF_VER = "11.8.87"
$CCCL_VER = "11.8.89"
# Create the directory where the CUDA Toolkit will be installed
mkdir -p $CUDA_TOOLKIT_DIR
# Install unzip to extract the downloaded files
choco install unzip -y
# Download all the required components
curl -O "$CUDA_DOWNLOAD/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-${CUDART_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-${NVCC_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/libcublas/windows-x86_64/libcublas-windows-x86_64-${CUBLAS_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-${NVTX_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-${VS_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-${CCCL_VER}-archive.zip"
# Extract all the downloaded files to the CUDA Toolkit directory
unzip '*.zip' -d $CUDA_TOOLKIT_DIR
# Copy all the extracted files to the main CUDA Toolkit directory
xcopy "$CUDA_TOOLKIT_DIR\cuda_cudart-windows-x86_64-${CUDART_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\cuda_nvcc-windows-x86_64-${NVCC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\libcublas-windows-x86_64-${CUBLAS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\cuda_nvtx-windows-x86_64-${NVTX_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\cuda_cccl-windows-x86_64-${CCCL_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
# Visual Studio integration
xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\MSBuild\Microsoft\VC\v160\BuildCustomizations" /E /I /H /Y
# Set environment variables
echo "$CUDA_TOOLKIT_DIR\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "$CUDA_TOOLKIT_DIR\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "CUDA_PATH=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8
echo "CUDA_PATH_V11_8=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8
- name: Install Cuda Toolkit 12.2.0
if: ${{ matrix.cuda-toolkit == '12.2.0' }}
run: |
$CUDA_VERSION = ${{ matrix.cuda-toolkit }}
$CUDA_TOOLKIT_DIR = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$CUDA_VERSION"
$CUDA_DOWNLOAD = "https://developer.download.nvidia.com/compute/cuda/redist"
# Components versions
$CUDART_VER = "12.2.140"
$NVCC_VER = "12.2.140"
$NVRTC_VER = "12.2.140"
$CUBLAS_VER = "12.2.5.6"
$NVTX_VER = "12.2.140"
$PROFILER_VER = "12.2.140"
$VS_VER = "12.2.140"
$NVPROF_VER = "12.2.142"
$CCCL_VER = "12.2.140"
# Create the directory where the CUDA Toolkit will be installed
mkdir -p $CUDA_TOOLKIT_DIR
# Install unzip to extract the downloaded files
choco install unzip -y
# Download all the required components
curl -O "$CUDA_DOWNLOAD/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-${CUDART_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-${NVCC_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/libcublas/windows-x86_64/libcublas-windows-x86_64-${CUBLAS_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-${NVTX_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/cuda_profiler_api/windows-x86_64/cuda_profiler_api-windows-x86_64-${PROFILER_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-${VS_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive.zip"
curl -O "$CUDA_DOWNLOAD/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-${CCCL_VER}-archive.zip"
# Extract all the downloaded files to the CUDA Toolkit directory
unzip -q '*.zip' -d $CUDA_TOOLKIT_DIR
# Copy all the extracted files to the main CUDA Toolkit directory
xcopy "$CUDA_TOOLKIT_DIR\cuda_cudart-windows-x86_64-${CUDART_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\cuda_nvcc-windows-x86_64-${NVCC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\libcublas-windows-x86_64-${CUBLAS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\cuda_nvtx-windows-x86_64-${NVTX_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\cuda_cccl-windows-x86_64-${CCCL_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\cuda_profiler_api-windows-x86_64-${PROFILER_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
# Visual Studio integration
xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\MSBuild\Microsoft\VC\v160\BuildCustomizations" /E /I /H /Y
# Set environment variables
echo "$CUDA_TOOLKIT_DIR\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "$CUDA_TOOLKIT_DIR\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "CUDA_PATH=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8
echo "CUDA_PATH_V12_2=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8
- name: Add msbuild to PATH - name: Add msbuild to PATH
uses: microsoft/setup-msbuild@v2 uses: microsoft/setup-msbuild@v2
- name: Install CUDA Toolkit
id: cuda-toolkit
uses: Jimver/cuda-toolkit@v0.2.15
with:
cuda: '${{ matrix.cuda-toolkit }}'
- name: Install 7-Zip - name: Install 7-Zip
run: choco install 7zip -y run: choco install 7zip -y
@ -844,30 +610,25 @@ jobs:
echo "SDL2_DIR=${{ github.workspace }}\SDL2-${{ matrix.sdl2_ver }}\cmake" | Out-File -FilePath $env:GITHUB_ENV -Append echo "SDL2_DIR=${{ github.workspace }}\SDL2-${{ matrix.sdl2_ver }}\cmake" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "${{ github.workspace }}\SDL2-${{ matrix.sdl2_ver }}\cmake" > SDL2_PATH.txt echo "${{ github.workspace }}\SDL2-${{ matrix.sdl2_ver }}\cmake" > SDL2_PATH.txt
- name: Install cmake - name: Configure CMake
run: choco install cmake shell: cmd
run: |
cmake -S . -B ./build -A ${{ matrix.arch }} ^
-DCMAKE_BUILD_TYPE=${{ matrix.build }} ^
-DGGML_CUDA=${{ matrix.cublas }} ^
-DCMAKE_CUDA_ARCHITECTURES=all ^
-DWHISPER_SDL2=${{ matrix.sdl2 }} ^
-DSDL2_DIR="%SDL2_DIR%"
- name: Build Project - name: Build Project
shell: cmd shell: cmd
run: | run: |
call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat" cd ./build
cmake --version cmake --build . --config ${{ matrix.build }}
where cmake
cmake -S . -B build -G "Ninja Multi-Config" ^
-DCMAKE_BUILD_TYPE=${{ matrix.build }} ^
-DGGML_CUDA=${{ matrix.cublas }} ^
-DWHISPER_SDL2=${{ matrix.sdl2 }} ^
-DSDL2_DIR="%SDL2_DIR%"
set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1
cmake --build build --config ${{ matrix.build }} -j %NUMBER_OF_PROCESSORS%
- name: Check sccache status after build
run: |
sccache --show-stats
- name: Copy CUDA DLLs - name: Copy CUDA DLLs
run: | run: |
Get-ChildItem "$env:CUDA_PATH\bin\" -Filter "*.dll" | Get-ChildItem "${{ steps.cuda-toolkit.outputs.CUDA_PATH }}/bin/" -Filter "*.dll" |
Copy-Item -Destination "build/bin/${{ matrix.build }}" Copy-Item -Destination "build/bin/${{ matrix.build }}"
- name: Copy SDL2.dll - name: Copy SDL2.dll
@ -881,8 +642,6 @@ jobs:
path: build/bin/${{ matrix.build }} path: build/bin/${{ matrix.build }}
emscripten: emscripten:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
strategy: strategy:
@ -906,7 +665,6 @@ jobs:
ios-xcode-build: ios-xcode-build:
runs-on: macos-latest runs-on: macos-latest
needs: determine-tag
strategy: strategy:
matrix: matrix:
@ -949,26 +707,7 @@ jobs:
- name: Build swiftui example - name: Build swiftui example
run: xcodebuild -project examples/whisper.swiftui/whisper.swiftui.xcodeproj -scheme WhisperCppDemo -configuration ${{ matrix.build }} -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' FRAMEWORK_FOLDER_PATH=./build-ios build run: xcodebuild -project examples/whisper.swiftui/whisper.swiftui.xcodeproj -scheme WhisperCppDemo -configuration ${{ matrix.build }} -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' FRAMEWORK_FOLDER_PATH=./build-ios build
- name: Pack artifacts
id: pack_artifacts
if: ${{ (github.event_name == 'push' && github.ref == 'refs/heads/master') ||
github.event.inputs.create_release == 'true' ||
github.event.inputs.pre_release_tag != '' }}
run: |
zip --symlinks -r whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip build-apple/whisper.xcframework
- name: Upload artifacts
if: ${{ (github.event_name == 'push' && github.ref == 'refs/heads/master') ||
github.event.inputs.create_release == 'true' ||
github.event.inputs.pre_release_tag != '' }}
uses: actions/upload-artifact@v4
with:
path: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip
name: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip
android: android:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
@ -997,113 +736,64 @@ jobs:
cd whisper/examples/whisper.android cd whisper/examples/whisper.android
./gradlew assembleRelease --no-daemon ./gradlew assembleRelease --no-daemon
android_java: # TODO: disable because of following fail: https://github.com/ggerganov/whisper.cpp/actions/runs/11019444420/job/30627193602
runs-on: ubuntu-22.04 # android_java:
# runs-on: ubuntu-22.04
steps: #
- name: Clone # steps:
uses: actions/checkout@v4 # - name: Clone
# uses: actions/checkout@v4
- name: set up JDK 11 #
uses: actions/setup-java@v4 # - name: set up JDK 11
with: # uses: actions/setup-java@v4
java-version: '11' # with:
distribution: 'temurin' # java-version: '11'
cache: gradle # distribution: 'temurin'
# cache: gradle
- name: Setup Android SDK #
uses: android-actions/setup-android@v3 # - name: Setup Android SDK
with: # uses: android-actions/setup-android@v3
cmdline-tools-version: 9.0 # with:
# cmdline-tools-version: 9.0
- name: Build #
run: | # - name: Build
cd examples/whisper.android.java # run: |
chmod +x ./gradlew # cd examples/whisper.android.java
./gradlew assembleRelease # chmod +x ./gradlew
# ./gradlew assembleRelease
bindings-java:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
needs: ['windows']
runs-on: windows-latest
steps:
- uses: actions/checkout@v4
- name: Install Java
uses: actions/setup-java@v4
with:
distribution: zulu
java-version: 20
- name: Download Whisper Windows lib
uses: actions/download-artifact@v4
with:
name: whisper_x64.dll
- name: Download GGML Windows lib
uses: actions/download-artifact@v4
with:
name: ggml_x64.dll
- name: Download GGML Base Windows lib
uses: actions/download-artifact@v4
with:
name: ggml_base_x64.dll
- name: Download GGML CPU Windows lib
uses: actions/download-artifact@v4
with:
name: ggml_cpu_x64.dll
- name: Download SDL2.dll
uses: actions/download-artifact@v4
with:
name: x64_SDL2.dll
- name: List downloaded files
shell: pwsh
run: |
Get-ChildItem -Path "." -Recurse -Filter "*.dll"
- name: Move DLL to correct location
shell: pwsh
run: |
New-Item -Path "build\bin\Release" -ItemType Directory -Force
Copy-Item -Path "whisper.dll" -Destination "build\bin\Release\whisper.dll" -Force
Write-Host "Copied whisper.dll to build\bin\Release\whisper.dll directory"
Copy-Item -Path "ggml.dll" -Destination "build\bin\Release\ggml.dll" -Force
Write-Host "Copied ggml.dll to build\bin\Release\ggml.dll directory"
Copy-Item -Path "ggml-base.dll" -Destination "build\bin\Release\ggml-base.dll" -Force
Write-Host "Copied ggml-base.dll to build\bin\Release\ggml-base.dll directory"
Copy-Item -Path "ggml-cpu.dll" -Destination "build\bin\Release\ggml-cpu.dll" -Force
Write-Host "Copied ggml-cpu.dll to build\bin\Release\ggml-cpu.dll directory"
Copy-Item -Path "SDL2.dll" -Destination "build\bin\Release\SDL2.dll" -Force
Write-Host "Copied SDL2.dll to build\bin\Release\SDL2.dll directory"
- name: List build release files
shell: pwsh
run: |
Get-ChildItem -Path "build\Release" -Recurse -Filter "*.dll"
- name: Build
run: |
models\download-ggml-model.cmd tiny.en models/
cd bindings/java
chmod +x ./gradlew
./gradlew build --info
- name: Upload jar
uses: actions/upload-artifact@v4
with:
name: whispercpp.jar
path: bindings/java/build/libs/whispercpp-*.jar
# TODO: disabled because of following fail: https://github.com/ggerganov/whisper.cpp/actions/runs/9686220096/job/26735899598
# java:
# needs: [ 'windows' ]
# runs-on: windows-latest
# steps:
# - uses: actions/checkout@v4
#
# - name: Install Java
# uses: actions/setup-java@v4
# with:
# distribution: zulu
# java-version: 20
#
# - name: Download Windows lib
# uses: actions/download-artifact@v4
# with:
# name: win32-x86-64_whisper.dll
# path: bindings/java/build/generated/resources/main/win32-x86-64
#
# - name: Build
# run: |
# models\download-ggml-model.cmd tiny.en
# cd bindings/java
# chmod +x ./gradlew
# ./gradlew build
#
# - name: Upload jar
# uses: actions/upload-artifact@v4
# with:
# name: whispercpp.jar
# path: bindings/java/build/libs/whispercpp-*.jar
#
# - name: Publish package # - name: Publish package
# if: ${{ github.ref == 'refs/heads/master' }} # if: ${{ github.ref == 'refs/heads/master' }}
# uses: gradle/gradle-build-action@v2.4.2 # uses: gradle/gradle-build-action@v2.4.2
@ -1117,8 +807,6 @@ jobs:
# PGP_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }} # PGP_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }}
quantize: quantize:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
@ -1131,95 +819,3 @@ jobs:
cmake -B build cmake -B build
cmake --build build --config Release cmake --build build --config Release
./build/bin/quantize models/ggml-tiny.en.bin models/ggml-tiny.en-q4_0.bin q4_0 ./build/bin/quantize models/ggml-tiny.en.bin models/ggml-tiny.en-q4_0.bin q4_0
release:
if: ${{ github.event.inputs.create_release == 'true' || github.event.inputs.pre_release_tag != '' }}
runs-on: ubuntu-latest
needs:
- determine-tag
- ios-xcode-build
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: ccache
uses: hendrikmuhs/ccache-action@v1.2.16
with:
key: release
evict-old-files: 1d
# Downloads all the artifacts from the previous jobs
- name: Download artifacts
id: download-artifact
uses: actions/download-artifact@v4
with:
path: ./artifact
- name: Move artifacts
id: move_artifacts
run: mkdir -p ./artifact/release && mv ./artifact/*/*.zip ./artifact/release
- name: Create release
id: create_release
uses: ggml-org/action-create-release@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
tag_name: ${{ needs.determine-tag.outputs.tag_name }}
prerelease: ${{ github.event.inputs.pre_release_tag != '' }}
- name: Upload release
id: upload_release
uses: actions/github-script@v3
with:
github-token: ${{secrets.GITHUB_TOKEN}}
script: |
const path = require('path');
const fs = require('fs');
const release_id = '${{ steps.create_release.outputs.id }}';
for (let file of await fs.readdirSync('./artifact/release')) {
if (path.extname(file) === '.zip') {
console.log('uploadReleaseAsset', file);
await github.repos.uploadReleaseAsset({
owner: context.repo.owner,
repo: context.repo.repo,
release_id: release_id,
name: file,
data: await fs.readFileSync(`./artifact/release/${file}`)
});
}
}
coreml-base-en:
if: ${{ (github.event_name == 'push' && github.ref == 'refs/heads/master') ||
github.event.inputs.create_release == 'true' ||
github.event.inputs.pre_release_tag != '' }}
runs-on: macos-latest
needs: determine-tag
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set environment variables
id: set_vars
run: |
echo "MODEL_NAME=base.en" >> $GITHUB_ENV
echo "GEN_MODEL_NAME=whisper-${{ needs.determine-tag.outputs.tag_name }}-ggml-base.en-encoder.mlmodelc" >> $GITHUB_ENV
- name: Download model
run: |
./models/download-ggml-model.sh ${{ env.MODEL_NAME }}
- name: Generate CoreML model
run: |
python3.11 -m venv venv
source venv/bin/activate
pip install ane_transformers openai-whisper coremltools
./models/generate-coreml-model.sh ${{ env.MODEL_NAME }}

View File

@ -18,7 +18,6 @@ jobs:
matrix: matrix:
config: config:
- { tag: "main", dockerfile: ".devops/main.Dockerfile", platform: "linux/amd64" } - { tag: "main", dockerfile: ".devops/main.Dockerfile", platform: "linux/amd64" }
- { tag: "main-musa", dockerfile: ".devops/main-musa.Dockerfile", platform: "linux/amd64" }
#TODO: the cuda image keeps failing - disable for now #TODO: the cuda image keeps failing - disable for now
# https://github.com/ggerganov/whisper.cpp/actions/runs/11019444428/job/30602020339 # https://github.com/ggerganov/whisper.cpp/actions/runs/11019444428/job/30602020339
#- { tag: "main-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platform: "linux/amd64" } #- { tag: "main-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platform: "linux/amd64" }

View File

@ -1,91 +0,0 @@
name: Examples WASM
on:
push:
branches: ["master"]
workflow_dispatch:
permissions:
contents: read
pages: write
id-token: write
concurrency:
group: "pages"
cancel-in-progress: false
jobs:
deploy-wasm-github-pages:
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Setup Pages
uses: actions/configure-pages@v4
- name: Setup emsdk
uses: mymindstorm/setup-emsdk@v14
- name: Build WASM Examples
# Enable for real build later in whisper.cpp
run: |
mkdir -p build-em && cd build-em
emcmake cmake .. -DCMAKE_BUILD_TYPE=Release
make -j
- name: Create staging directory
run: mkdir -p staging
- name: Create .nojekyll file in staging directory
run: touch staging/.nojekyll
- name: Copy application files
run: |
build_dir=build-em/bin
ls ${build_dir}
# command.wasm
target_dir=staging/command.wasm
mkdir -p ${target_dir}
cp ${build_dir}/command.wasm/{index.html,command.js,helpers.js} ${target_dir}
cp ${build_dir}/libcommand.js ${target_dir}
# bench.wasm
target_dir=staging/bench.wasm
mkdir -p ${target_dir}
cp ${build_dir}/bench.wasm/{index.html,bench.js,helpers.js} ${target_dir}
cp ${build_dir}/libbench.js ${target_dir}
# stream.wasm
target_dir=staging/stream.wasm
mkdir -p ${target_dir}
cp ${build_dir}/stream.wasm/{index.html,stream.js,helpers.js} ${target_dir}
cp ${build_dir}/libstream.js ${target_dir}
# whisper.wasm (this will be the main example page)
target_dir=staging
mkdir -p ${target_dir}
cp ${build_dir}/whisper.wasm/{index.html,main.js,helpers.js} ${target_dir}
cp ${build_dir}/libmain.js ${target_dir}
# Copy Cross-Origin Isolation service worker
cp -v examples/coi-serviceworker.js staging/
- name: List files in staging directory (for debugging)
run: |
echo "Files in staging directory:"
find staging -type f | sort
- name: Upload artifact
uses: actions/upload-pages-artifact@v3
with:
path: ./staging
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4

0
.gitmodules vendored Normal file
View File

View File

@ -1,6 +1,6 @@
cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories. cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories.
project("whisper.cpp" C CXX) project("whisper.cpp" C CXX)
project("whisper.cpp" VERSION 1.7.5) project("whisper.cpp" VERSION 1.7.4)
include(CheckIncludeFileCXX) include(CheckIncludeFileCXX)
set(SOVERSION 1) set(SOVERSION 1)
@ -38,13 +38,8 @@ if (EMSCRIPTEN)
# TODO: without these, we get the following error: # TODO: without these, we get the following error:
# wasm-ld: error: --shared-memory is disallowed by whisper.cpp.o because it was not compiled with 'atomics' or 'bulk-memory' features. # wasm-ld: error: --shared-memory is disallowed by whisper.cpp.o because it was not compiled with 'atomics' or 'bulk-memory' features.
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread -s TOTAL_STACK=5242880")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread -s TOTAL_STACK=5242880")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -s TOTAL_STACK=5242880")
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -s TOTAL_STACK=5242880")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-deprecated")
else() else()
if (MINGW) if (MINGW)
set(BUILD_SHARED_LIBS_DEFAULT OFF) set(BUILD_SHARED_LIBS_DEFAULT OFF)
@ -68,7 +63,6 @@ option(WHISPER_ALL_WARNINGS_3RD_PARTY "whisper: enable all compiler warnings in
# build # build
option(WHISPER_FATAL_WARNINGS "whisper: enable -Werror flag" OFF) option(WHISPER_FATAL_WARNINGS "whisper: enable -Werror flag" OFF)
option(WHISPER_USE_SYSTEM_GGML "whisper: use system-installed GGML library" OFF)
# sanitizers # sanitizers
option(WHISPER_SANITIZE_THREAD "whisper: enable thread sanitizer" OFF) option(WHISPER_SANITIZE_THREAD "whisper: enable thread sanitizer" OFF)
@ -127,31 +121,7 @@ whisper_option_depr(WARNING WHISPER_SYCL_F16 GGML_SYCL_F16)
# #
if (NOT TARGET ggml) if (NOT TARGET ggml)
if (WHISPER_USE_SYSTEM_GGML)
find_package(ggml REQUIRED)
if (NOT ggml_FOUND)
message(FATAL_ERROR "System-installed GGML library not found.")
endif()
add_library(ggml ALIAS ggml::ggml)
else()
add_subdirectory(ggml) add_subdirectory(ggml)
if(WIN32)
# The following adds a _DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR macro and is a workaround for
# the Windows C++ standard library which does not support constexpr mutexes.
# From the release notes://github.com/microsoft/STL/wiki/Changelog
# Disable constexpr mutex constructor on Windows
# Fixed mutex's constructor to be constexpr. #3824 #4000 #4339
# Note: Programs that aren't following the documented restrictions on binary compatibility may encounter
# null dereferences in mutex machinery. You must follow this rule:
# When you mix binaries built by different supported versions of the toolset, the Redistributable version
# must be at least as new as the latest toolset used by any app component.
# You can define _DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR as an escape hatch.
#
# Specifically to whisper.cpp this would cause a crash when using the Java bindings.
# resulting in a Invalid memory access error.
target_compile_definitions(ggml-base PRIVATE _DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR)
endif()
endif()
# ... otherwise assume ggml is added by a parent CMakeLists.txt # ... otherwise assume ggml is added by a parent CMakeLists.txt
endif() endif()
add_subdirectory(src) add_subdirectory(src)
@ -206,43 +176,10 @@ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/whisper.pc"
# #
if (WHISPER_BUILD_TESTS AND NOT CMAKE_JS_VERSION) if (WHISPER_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
include(CTest) #include(CTest)
add_subdirectory(tests) #add_subdirectory(tests)
endif () endif ()
if (WHISPER_BUILD_EXAMPLES) if (WHISPER_BUILD_EXAMPLES)
add_subdirectory(examples) add_subdirectory(examples)
endif() endif()
if (MSVC)
set(MSVC_WARNING_FLAGS
/wd4101 # Unreferenced local variable
/wd4005 # Macro redefinition
/wd4065 # switch statement contains 'default' but no 'case' labels
/wd4267 # Conversion from 'size_t' to a smaller type, possible loss of data
/wd4244 # Conversion from one type to another type, possible loss of ata
/wd4805 # Unsafe mix of type
/wd4305 # Truncation from 'type1' to 'type2' (often double to float)
/wd4996 # Function or variable may be unsafe/deprecated
)
function(disable_msvc_warnings target_name)
if(TARGET ${target_name})
target_compile_options(${target_name} PRIVATE ${MSVC_WARNING_FLAGS})
endif()
endfunction()
if (WHISPER_BUILD_EXAMPLES)
disable_msvc_warnings(whisper)
disable_msvc_warnings(common)
disable_msvc_warnings(common-sdl)
disable_msvc_warnings(lsp)
disable_msvc_warnings(wchess-core)
disable_msvc_warnings(whisper-command)
disable_msvc_warnings(whisper-cli)
disable_msvc_warnings(whisper-server)
disable_msvc_warnings(whisper-stream)
disable_msvc_warnings(whisper-talk-llama)
disable_msvc_warnings(whisper-bench)
disable_msvc_warnings(quantize)
endif()
endif()

View File

@ -4,7 +4,7 @@
.PHONY: build .PHONY: build
build: build:
cmake -B build $(CMAKE_ARGS) cmake -B build
cmake --build build --config Release cmake --build build --config Release
# download a few audio samples into folder "./samples": # download a few audio samples into folder "./samples":
@ -41,14 +41,14 @@ samples:
tiny.en tiny base.en base small.en small medium.en medium large-v1 large-v2 large-v3 large-v3-turbo: tiny.en tiny base.en base small.en small medium.en medium large-v1 large-v2 large-v3 large-v3-turbo:
bash ./models/download-ggml-model.sh $@ bash ./models/download-ggml-model.sh $@
cmake -B build $(CMAKE_ARGS) cmake -B build
cmake --build build --config Release cmake --build build --config Release
@echo "" @echo ""
@echo "===============================================" @echo "==============================================="
@echo "Running $@ on all samples in ./samples ..." @echo "Running $@ on all samples in ./samples ..."
@echo "===============================================" @echo "==============================================="
@echo "" @echo ""
@for f in samples/*.{flac,mp3,ogg,wav}; do \ @for f in samples/*$(.flac .mp3 .ogg .wav); do \
echo "----------------------------------------------" ; \ echo "----------------------------------------------" ; \
echo "[+] Running $@ on $$f ... (run 'ffplay $$f' to listen)" ; \ echo "[+] Running $@ on $$f ... (run 'ffplay $$f' to listen)" ; \
echo "----------------------------------------------" ; \ echo "----------------------------------------------" ; \

161
README.md
View File

@ -2,12 +2,15 @@
![whisper.cpp](https://user-images.githubusercontent.com/1991296/235238348-05d0f6a4-da44-4900-a1de-d0707e75b763.jpeg) ![whisper.cpp](https://user-images.githubusercontent.com/1991296/235238348-05d0f6a4-da44-4900-a1de-d0707e75b763.jpeg)
[![Actions Status](https://github.com/ggml-org/whisper.cpp/workflows/CI/badge.svg)](https://github.com/ggml-org/whisper.cpp/actions) [![Actions Status](https://github.com/ggerganov/whisper.cpp/workflows/CI/badge.svg)](https://github.com/ggerganov/whisper.cpp/actions)
[![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT) [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT)
[![Conan Center](https://shields.io/conan/v/whisper-cpp)](https://conan.io/center/whisper-cpp) [![Conan Center](https://shields.io/conan/v/whisper-cpp)](https://conan.io/center/whisper-cpp)
[![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/) [![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/)
Stable: [v1.7.5](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.7.5) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/) > [!NOTE]
> New maintenance roadmap: https://github.com/ggerganov/whisper.cpp/discussions/2788
Stable: [v1.7.4](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.7.4) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model: High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
@ -23,8 +26,7 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp
- [Efficient GPU support for NVIDIA](#nvidia-gpu-support) - [Efficient GPU support for NVIDIA](#nvidia-gpu-support)
- [OpenVINO Support](#openvino-support) - [OpenVINO Support](#openvino-support)
- [Ascend NPU Support](#ascend-npu-support) - [Ascend NPU Support](#ascend-npu-support)
- [Moore Threads GPU Support](#moore-threads-gpu-support) - [C-style API](https://github.com/ggerganov/whisper.cpp/blob/master/include/whisper.h)
- [C-style API](https://github.com/ggml-org/whisper.cpp/blob/master/include/whisper.h)
Supported platforms: Supported platforms:
@ -32,14 +34,14 @@ Supported platforms:
- [x] [iOS](examples/whisper.objc) - [x] [iOS](examples/whisper.objc)
- [x] [Android](examples/whisper.android) - [x] [Android](examples/whisper.android)
- [x] [Java](bindings/java/README.md) - [x] [Java](bindings/java/README.md)
- [x] Linux / [FreeBSD](https://github.com/ggml-org/whisper.cpp/issues/56#issuecomment-1350920264) - [x] Linux / [FreeBSD](https://github.com/ggerganov/whisper.cpp/issues/56#issuecomment-1350920264)
- [x] [WebAssembly](examples/whisper.wasm) - [x] [WebAssembly](examples/whisper.wasm)
- [x] Windows ([MSVC](https://github.com/ggml-org/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggml-org/whisper.cpp/issues/168)] - [x] Windows ([MSVC](https://github.com/ggerganov/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggerganov/whisper.cpp/issues/168)]
- [x] [Raspberry Pi](https://github.com/ggml-org/whisper.cpp/discussions/166) - [x] [Raspberry Pi](https://github.com/ggerganov/whisper.cpp/discussions/166)
- [x] [Docker](https://github.com/ggml-org/whisper.cpp/pkgs/container/whisper.cpp) - [x] [Docker](https://github.com/ggerganov/whisper.cpp/pkgs/container/whisper.cpp)
The entire high-level implementation of the model is contained in [whisper.h](include/whisper.h) and [whisper.cpp](src/whisper.cpp). The entire high-level implementation of the model is contained in [whisper.h](include/whisper.h) and [whisper.cpp](src/whisper.cpp).
The rest of the code is part of the [`ggml`](https://github.com/ggml-org/ggml) machine learning library. The rest of the code is part of the [`ggml`](https://github.com/ggerganov/ggml) machine learning library.
Having such a lightweight implementation of the model allows to easily integrate it in different platforms and applications. Having such a lightweight implementation of the model allows to easily integrate it in different platforms and applications.
As an example, here is a video of running the model on an iPhone 13 device - fully offline, on-device: [whisper.objc](examples/whisper.objc) As an example, here is a video of running the model on an iPhone 13 device - fully offline, on-device: [whisper.objc](examples/whisper.objc)
@ -52,14 +54,14 @@ https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a
On Apple Silicon, the inference runs fully on the GPU via Metal: On Apple Silicon, the inference runs fully on the GPU via Metal:
https://github.com/ggml-org/whisper.cpp/assets/1991296/c82e8f86-60dc-49f2-b048-d2fdbd6b5225 https://github.com/ggerganov/whisper.cpp/assets/1991296/c82e8f86-60dc-49f2-b048-d2fdbd6b5225
## Quick start ## Quick start
First clone the repository: First clone the repository:
```bash ```bash
git clone https://github.com/ggml-org/whisper.cpp.git git clone https://github.com/ggerganov/whisper.cpp.git
``` ```
Navigate into the directory: Navigate into the directory:
@ -150,7 +152,6 @@ standard cmake setup with:
cmake -B build -DGGML_BLAS=1 cmake -B build -DGGML_BLAS=1
cmake --build build --config Release cmake --build build --config Release
./build/bin/whisper-cli [ .. etc .. ] ./build/bin/whisper-cli [ .. etc .. ]
```
## Quantization ## Quantization
@ -183,11 +184,11 @@ speed-up - more than x3 faster compared with CPU-only execution. Here are the in
``` ```
- To ensure `coremltools` operates correctly, please confirm that [Xcode](https://developer.apple.com/xcode/) is installed and execute `xcode-select --install` to install the command-line tools. - To ensure `coremltools` operates correctly, please confirm that [Xcode](https://developer.apple.com/xcode/) is installed and execute `xcode-select --install` to install the command-line tools.
- Python 3.11 is recommended. - Python 3.10 is recommended.
- MacOS Sonoma (version 14) or newer is recommended, as older versions of MacOS might experience issues with transcription hallucination. - MacOS Sonoma (version 14) or newer is recommended, as older versions of MacOS might experience issues with transcription hallucination.
- [OPTIONAL] It is recommended to utilize a Python version management system, such as [Miniconda](https://docs.conda.io/en/latest/miniconda.html) for this step: - [OPTIONAL] It is recommended to utilize a Python version management system, such as [Miniconda](https://docs.conda.io/en/latest/miniconda.html) for this step:
- To create an environment, use: `conda create -n py311-whisper python=3.11 -y` - To create an environment, use: `conda create -n py310-whisper python=3.10 -y`
- To activate the environment, use: `conda activate py311-whisper` - To activate the environment, use: `conda activate py310-whisper`
- Generate a Core ML model. For example, to generate a `base.en` model, use: - Generate a Core ML model. For example, to generate a `base.en` model, use:
@ -224,7 +225,7 @@ speed-up - more than x3 faster compared with CPU-only execution. Here are the in
The first run on a device is slow, since the ANE service compiles the Core ML model to some device-specific format. The first run on a device is slow, since the ANE service compiles the Core ML model to some device-specific format.
Next runs are faster. Next runs are faster.
For more information about the Core ML implementation please refer to PR [#566](https://github.com/ggml-org/whisper.cpp/pull/566). For more information about the Core ML implementation please refer to PR [#566](https://github.com/ggerganov/whisper.cpp/pull/566).
## OpenVINO support ## OpenVINO support
@ -309,7 +310,7 @@ This can result in significant speedup in encoder performance. Here are the inst
The first time run on an OpenVINO device is slow, since the OpenVINO framework will compile the IR (Intermediate Representation) model to a device-specific 'blob'. This device-specific blob will get The first time run on an OpenVINO device is slow, since the OpenVINO framework will compile the IR (Intermediate Representation) model to a device-specific 'blob'. This device-specific blob will get
cached for the next run. cached for the next run.
For more information about the OpenVINO implementation please refer to PR [#1037](https://github.com/ggml-org/whisper.cpp/pull/1037). For more information about the OpenVINO implementation please refer to PR [#1037](https://github.com/ggerganov/whisper.cpp/pull/1037).
## NVIDIA GPU support ## NVIDIA GPU support
@ -323,12 +324,6 @@ cmake -B build -DGGML_CUDA=1
cmake --build build -j --config Release cmake --build build -j --config Release
``` ```
or for newer NVIDIA GPU's (RTX 5000 series):
```
cmake -B build -DGGML_CUDA=1 -DCMAKE_CUDA_ARCHITECTURES="86"
cmake --build build -j --config Release
```
## Vulkan GPU support ## Vulkan GPU support
Cross-vendor solution which allows you to accelerate workload on your GPU. Cross-vendor solution which allows you to accelerate workload on your GPU.
First, make sure your graphics card driver provides support for Vulkan API. First, make sure your graphics card driver provides support for Vulkan API.
@ -382,56 +377,6 @@ Run the inference examples as usual, for example:
- If you have trouble with Ascend NPU device, please create a issue with **[CANN]** prefix/tag. - If you have trouble with Ascend NPU device, please create a issue with **[CANN]** prefix/tag.
- If you run successfully with your Ascend NPU device, please help update the table `Verified devices`. - If you run successfully with your Ascend NPU device, please help update the table `Verified devices`.
## Moore Threads GPU support
With Moore Threads cards the processing of the models is done efficiently on the GPU via muBLAS and custom MUSA kernels.
First, make sure you have installed `MUSA SDK rc3.1.1`: https://developer.mthreads.com/sdk/download/musa?equipment=&os=&driverVersion=&version=rc3.1.1
Now build `whisper.cpp` with MUSA support:
```
cmake -B build -DGGML_MUSA=1
cmake --build build -j --config Release
```
or specify the architecture for your Moore Threads GPU. For example, if you have a MTT S80 GPU, you can specify the architecture as follows:
```
cmake -B build -DGGML_MUSA=1 -DMUSA_ARCHITECTURES="21"
cmake --build build -j --config Release
```
## FFmpeg support (Linux only)
If you want to support more audio formats (such as Opus and AAC), you can turn on the `WHISPER_FFMPEG` build flag to enable FFmpeg integration.
First, you need to install required libraries:
```bash
# Debian/Ubuntu
sudo apt install libavcodec-dev libavformat-dev libavutil-dev
# RHEL/Fedora
sudo dnf install libavcodec-free-devel libavformat-free-devel libavutil-free-devel
```
Then you can build the project as follows:
```bash
cmake -B build -D WHISPER_FFMPEG=yes
cmake --build build
```
Run the following example to confirm it's working:
```bash
# Convert an audio file to Opus format
ffmpeg -i samples/jfk.wav jfk.opus
# Transcribe the audio file
./build/bin/whisper-cli --model models/ggml-base.en.bin --file jfk.opus
```
## Docker ## Docker
### Prerequisites ### Prerequisites
@ -443,9 +388,8 @@ ffmpeg -i samples/jfk.wav jfk.opus
We have two Docker images available for this project: We have two Docker images available for this project:
1. `ghcr.io/ggml-org/whisper.cpp:main`: This image includes the main executable file as well as `curl` and `ffmpeg`. (platforms: `linux/amd64`, `linux/arm64`) 1. `ghcr.io/ggerganov/whisper.cpp:main`: This image includes the main executable file as well as `curl` and `ffmpeg`. (platforms: `linux/amd64`, `linux/arm64`)
2. `ghcr.io/ggml-org/whisper.cpp:main-cuda`: Same as `main` but compiled with CUDA support. (platforms: `linux/amd64`) 2. `ghcr.io/ggerganov/whisper.cpp:main-cuda`: Same as `main` but compiled with CUDA support. (platforms: `linux/amd64`)
3. `ghcr.io/ggml-org/whisper.cpp:main-musa`: Same as `main` but compiled with MUSA support. (platforms: `linux/amd64`)
### Usage ### Usage
@ -458,11 +402,11 @@ docker run -it --rm \
docker run -it --rm \ docker run -it --rm \
-v path/to/models:/models \ -v path/to/models:/models \
-v path/to/audios:/audios \ -v path/to/audios:/audios \
whisper.cpp:main "whisper-cli -m /models/ggml-base.bin -f /audios/jfk.wav" whisper.cpp:main "./main -m /models/ggml-base.bin -f /audios/jfk.wav"
# transcribe an audio file in samples folder # transcribe an audio file in samples folder
docker run -it --rm \ docker run -it --rm \
-v path/to/models:/models \ -v path/to/models:/models \
whisper.cpp:main "whisper-cli -m /models/ggml-base.bin -f ./samples/jfk.wav" whisper.cpp:main "./main -m /models/ggml-base.bin -f ./samples/jfk.wav"
``` ```
## Installing with Conan ## Installing with Conan
@ -483,8 +427,7 @@ For detailed instructions on how to use Conan, please refer to the [Conan docume
This is a naive example of performing real-time inference on audio from your microphone. This is a naive example of performing real-time inference on audio from your microphone.
The [stream](examples/stream) tool samples the audio every half a second and runs the transcription continuously. The [stream](examples/stream) tool samples the audio every half a second and runs the transcription continuously.
More info is available in [issue #10](https://github.com/ggml-org/whisper.cpp/issues/10). More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10).
You will need to have [sdl2](https://wiki.libsdl.org/SDL2/Installation) installed for it to work properly.
```bash ```bash
cmake -B build -DWHISPER_SDL2=ON cmake -B build -DWHISPER_SDL2=ON
@ -572,7 +515,7 @@ main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 pr
## Speaker segmentation via tinydiarize (experimental) ## Speaker segmentation via tinydiarize (experimental)
More information about this approach is available here: https://github.com/ggml-org/whisper.cpp/pull/1058 More information about this approach is available here: https://github.com/ggerganov/whisper.cpp/pull/1058
Sample usage: Sample usage:
@ -636,7 +579,7 @@ https://user-images.githubusercontent.com/1991296/199337538-b7b0c7a3-2753-4a88-a
## Video comparison of different models ## Video comparison of different models
Use the [scripts/bench-wts.sh](https://github.com/ggml-org/whisper.cpp/blob/master/scripts/bench-wts.sh) script to generate a video in the following format: Use the [scripts/bench-wts.sh](https://github.com/ggerganov/whisper.cpp/blob/master/scripts/bench-wts.sh) script to generate a video in the following format:
```bash ```bash
./scripts/bench-wts.sh samples/jfk.wav ./scripts/bench-wts.sh samples/jfk.wav
@ -653,7 +596,7 @@ In order to have an objective comparison of the performance of the inference acr
use the [whisper-bench](examples/bench) tool. The tool simply runs the Encoder part of the model and prints how much time it use the [whisper-bench](examples/bench) tool. The tool simply runs the Encoder part of the model and prints how much time it
took to execute it. The results are summarized in the following Github issue: took to execute it. The results are summarized in the following Github issue:
[Benchmark results](https://github.com/ggml-org/whisper.cpp/issues/89) [Benchmark results](https://github.com/ggerganov/whisper.cpp/issues/89)
Additionally a script to run whisper.cpp with different models and audio files is provided [bench.py](scripts/bench.py). Additionally a script to run whisper.cpp with different models and audio files is provided [bench.py](scripts/bench.py).
@ -680,24 +623,25 @@ You can download the converted models using the [models/download-ggml-model.sh](
or manually from here: or manually from here:
- https://huggingface.co/ggerganov/whisper.cpp - https://huggingface.co/ggerganov/whisper.cpp
- https://ggml.ggerganov.com
For more details, see the conversion script [models/convert-pt-to-ggml.py](models/convert-pt-to-ggml.py) or [models/README.md](models/README.md). For more details, see the conversion script [models/convert-pt-to-ggml.py](models/convert-pt-to-ggml.py) or [models/README.md](models/README.md).
## [Bindings](https://github.com/ggml-org/whisper.cpp/discussions/categories/bindings) ## [Bindings](https://github.com/ggerganov/whisper.cpp/discussions/categories/bindings)
- [x] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggml-org/whisper.cpp/discussions/310) - [x] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggerganov/whisper.cpp/discussions/310)
- [x] JavaScript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggml-org/whisper.cpp/discussions/309) - [x] JavaScript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309)
- React Native (iOS / Android): [whisper.rn](https://github.com/mybigday/whisper.rn) - React Native (iOS / Android): [whisper.rn](https://github.com/mybigday/whisper.rn)
- [x] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggml-org/whisper.cpp/discussions/312) - [x] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312)
- [x] Java: - [x] Java:
- [GiviMAD/whisper-jni](https://github.com/GiviMAD/whisper-jni) - [GiviMAD/whisper-jni](https://github.com/GiviMAD/whisper-jni)
- [x] Ruby: [bindings/ruby](bindings/ruby) | [#507](https://github.com/ggml-org/whisper.cpp/discussions/507) - [x] Ruby: [bindings/ruby](bindings/ruby) | [#507](https://github.com/ggerganov/whisper.cpp/discussions/507)
- [x] Objective-C / Swift: [ggml-org/whisper.spm](https://github.com/ggml-org/whisper.spm) | [#313](https://github.com/ggml-org/whisper.cpp/discussions/313) - [x] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm) | [#313](https://github.com/ggerganov/whisper.cpp/discussions/313)
- [exPHAT/SwiftWhisper](https://github.com/exPHAT/SwiftWhisper) - [exPHAT/SwiftWhisper](https://github.com/exPHAT/SwiftWhisper)
- [x] .NET: | [#422](https://github.com/ggml-org/whisper.cpp/discussions/422) - [x] .NET: | [#422](https://github.com/ggerganov/whisper.cpp/discussions/422)
- [sandrohanea/whisper.net](https://github.com/sandrohanea/whisper.net) - [sandrohanea/whisper.net](https://github.com/sandrohanea/whisper.net)
- [NickDarvey/whisper](https://github.com/NickDarvey/whisper) - [NickDarvey/whisper](https://github.com/NickDarvey/whisper)
- [x] Python: | [#9](https://github.com/ggml-org/whisper.cpp/issues/9) - [x] Python: | [#9](https://github.com/ggerganov/whisper.cpp/issues/9)
- [stlukey/whispercpp.py](https://github.com/stlukey/whispercpp.py) (Cython) - [stlukey/whispercpp.py](https://github.com/stlukey/whispercpp.py) (Cython)
- [AIWintermuteAI/whispercpp](https://github.com/AIWintermuteAI/whispercpp) (Updated fork of aarnphm/whispercpp) - [AIWintermuteAI/whispercpp](https://github.com/AIWintermuteAI/whispercpp) (Updated fork of aarnphm/whispercpp)
- [aarnphm/whispercpp](https://github.com/aarnphm/whispercpp) (Pybind11) - [aarnphm/whispercpp](https://github.com/aarnphm/whispercpp) (Pybind11)
@ -705,33 +649,6 @@ For more details, see the conversion script [models/convert-pt-to-ggml.py](model
- [x] R: [bnosac/audio.whisper](https://github.com/bnosac/audio.whisper) - [x] R: [bnosac/audio.whisper](https://github.com/bnosac/audio.whisper)
- [x] Unity: [macoron/whisper.unity](https://github.com/Macoron/whisper.unity) - [x] Unity: [macoron/whisper.unity](https://github.com/Macoron/whisper.unity)
## XCFramework
The XCFramework is a precompiled version of the library for iOS, visionOS, tvOS,
and macOS. It can be used in Swift projects without the need to compile the
library from source. For examples:
```swift
// swift-tools-version: 5.10
// The swift-tools-version declares the minimum version of Swift required to build this package.
import PackageDescription
let package = Package(
name: "Whisper",
targets: [
.executableTarget(
name: "Whisper",
dependencies: [
"WhisperFramework"
]),
.binaryTarget(
name: "WhisperFramework",
url: "https://github.com/ggml-org/whisper.cpp/releases/download/v1.7.5/whisper-v1.7.5-xcframework.zip",
checksum: "c7faeb328620d6012e130f3d705c51a6ea6c995605f2df50f6e1ad68c59c6c4a"
)
]
)
```
## Examples ## Examples
There are various examples of using the library for different projects in the [examples](examples) folder. There are various examples of using the library for different projects in the [examples](examples) folder.
@ -750,13 +667,13 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
| [whisper.android](examples/whisper.android) | | Android mobile application using whisper.cpp | | [whisper.android](examples/whisper.android) | | Android mobile application using whisper.cpp |
| [whisper.nvim](examples/whisper.nvim) | | Speech-to-text plugin for Neovim | | [whisper.nvim](examples/whisper.nvim) | | Speech-to-text plugin for Neovim |
| [generate-karaoke.sh](examples/generate-karaoke.sh) | | Helper script to easily [generate a karaoke video](https://youtu.be/uj7hVta4blM) of raw audio capture | | [generate-karaoke.sh](examples/generate-karaoke.sh) | | Helper script to easily [generate a karaoke video](https://youtu.be/uj7hVta4blM) of raw audio capture |
| [livestream.sh](examples/livestream.sh) | | [Livestream audio transcription](https://github.com/ggml-org/whisper.cpp/issues/185) | | [livestream.sh](examples/livestream.sh) | | [Livestream audio transcription](https://github.com/ggerganov/whisper.cpp/issues/185) |
| [yt-wsp.sh](examples/yt-wsp.sh) | | Download + transcribe and/or translate any VOD [(original)](https://gist.github.com/DaniruKun/96f763ec1a037cc92fe1a059b643b818) | | [yt-wsp.sh](examples/yt-wsp.sh) | | Download + transcribe and/or translate any VOD [(original)](https://gist.github.com/DaniruKun/96f763ec1a037cc92fe1a059b643b818) |
| [wchess](examples/wchess) | [wchess.wasm](examples/wchess) | Voice-controlled chess | | [wchess](examples/wchess) | [wchess.wasm](examples/wchess) | Voice-controlled chess |
## [Discussions](https://github.com/ggml-org/whisper.cpp/discussions) ## [Discussions](https://github.com/ggerganov/whisper.cpp/discussions)
If you have any kind of feedback about this project feel free to use the Discussions section and open a new topic. If you have any kind of feedback about this project feel free to use the Discussions section and open a new topic.
You can use the [Show and tell](https://github.com/ggml-org/whisper.cpp/discussions/categories/show-and-tell) category You can use the [Show and tell](https://github.com/ggerganov/whisper.cpp/discussions/categories/show-and-tell) category
to share your own projects that use `whisper.cpp`. If you have a question, make sure to check the to share your own projects that use `whisper.cpp`. If you have a question, make sure to check the
[Frequently asked questions (#126)](https://github.com/ggml-org/whisper.cpp/discussions/126) discussion. [Frequently asked questions (#126)](https://github.com/ggerganov/whisper.cpp/discussions/126) discussion.

View File

@ -11,11 +11,11 @@ UNAME_M := $(shell uname -m)
endif endif
GGML_METAL_PATH_RESOURCES := $(abspath ../..) GGML_METAL_PATH_RESOURCES := $(abspath ../..)
BUILD_DIR := build_go BUILD_DIR := build
MODELS_DIR := models MODELS_DIR := models
EXAMPLES_DIR := $(wildcard examples/*) EXAMPLES_DIR := $(wildcard examples/*)
INCLUDE_PATH := $(abspath ../../include):$(abspath ../../ggml/include) INCLUDE_PATH := $(abspath ../../include):$(abspath ../../ggml/include)
LIBRARY_PATH := $(abspath ../../${BUILD_DIR}/src:$(abspath ../../${BUILD_DIR}/ggml/src)) LIBRARY_PATH := $(abspath ../..)
ifeq ($(GGML_CUDA),1) ifeq ($(GGML_CUDA),1)
LIBRARY_PATH := $(LIBRARY_PATH):$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib/ LIBRARY_PATH := $(LIBRARY_PATH):$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib/
@ -29,10 +29,8 @@ endif
all: clean whisper examples all: clean whisper examples
whisper: mkdir whisper: mkdir
cmake -S ../.. -B ../../${BUILD_DIR} \ @echo Build whisper
-DCMAKE_BUILD_TYPE=Release \ @${MAKE} -C ../.. libwhisper.a
-DBUILD_SHARED_LIBS=OFF
cmake --build ../../${BUILD_DIR} --target whisper
test: model-small whisper modtidy test: model-small whisper modtidy
ifeq ($(UNAME_S),Darwin) ifeq ($(UNAME_S),Darwin)

View File

@ -31,7 +31,7 @@ func main() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
if err := context.Process(samples, nil, nil, nil); err != nil { if err := context.Process(samples, nil, nil); err != nil {
return err return err
} }
@ -51,7 +51,7 @@ func main() {
In order to build, you need to have the Go compiler installed. You can get it from [here](https://golang.org/dl/). Run the tests with: In order to build, you need to have the Go compiler installed. You can get it from [here](https://golang.org/dl/). Run the tests with:
```bash ```bash
git clone https://github.com/ggml-org/whisper.cpp.git git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp/bindings/go cd whisper.cpp/bindings/go
make test make test
``` ```
@ -98,7 +98,7 @@ The API Documentation:
Getting help: Getting help:
* Follow the discussion for the go bindings [here](https://github.com/ggml-org/whisper.cpp/discussions/312) * Follow the discussion for the go bindings [here](https://github.com/ggerganov/whisper.cpp/discussions/312)
## License ## License

View File

@ -1,5 +1,5 @@
/* /*
github.com/ggml-org/whisper.cpp/bindings/go github.com/ggerganov/whisper.cpp/bindings/go
provides a speech-to-text service bindings for the Go programming language. provides a speech-to-text service bindings for the Go programming language.
*/ */
package whisper package whisper

View File

@ -67,7 +67,7 @@ func Process(model whisper.Model, path string, flags *Flags) error {
// Process the data // Process the data
fmt.Fprintf(flags.Output(), " ...processing %q\n", path) fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
context.ResetTimings() context.ResetTimings()
if err := context.Process(data, nil, cb, nil); err != nil { if err := context.Process(data, cb, nil); err != nil {
return err return err
} }

View File

@ -71,10 +71,6 @@ func (context *context) Language() string {
return whisper.Whisper_lang_str(context.params.Language()) return whisper.Whisper_lang_str(context.params.Language())
} }
func (context *context) DetectedLanguage() string {
return whisper.Whisper_lang_str(context.model.ctx.Whisper_full_lang_id())
}
// Set translate flag // Set translate flag
func (context *context) SetTranslate(v bool) { func (context *context) SetTranslate(v bool) {
context.params.SetTranslate(v) context.params.SetTranslate(v)
@ -193,7 +189,6 @@ func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]f
// Process new sample data and return any errors // Process new sample data and return any errors
func (context *context) Process( func (context *context) Process(
data []float32, data []float32,
callEncoderBegin EncoderBeginCallback,
callNewSegment SegmentCallback, callNewSegment SegmentCallback,
callProgress ProgressCallback, callProgress ProgressCallback,
) error { ) error {
@ -208,8 +203,7 @@ func (context *context) Process(
// We don't do parallel processing at the moment // We don't do parallel processing at the moment
processors := 0 processors := 0
if processors > 1 { if processors > 1 {
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, callEncoderBegin, if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
func(new int) {
if callNewSegment != nil { if callNewSegment != nil {
num_segments := context.model.ctx.Whisper_full_n_segments() num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new s0 := num_segments - new
@ -220,8 +214,7 @@ func (context *context) Process(
}); err != nil { }); err != nil {
return err return err
} }
} else if err := context.model.ctx.Whisper_full(context.params, data, callEncoderBegin, } else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
func(new int) {
if callNewSegment != nil { if callNewSegment != nil {
num_segments := context.model.ctx.Whisper_full_n_segments() num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new s0 := num_segments - new

View File

@ -88,37 +88,6 @@ func TestProcess(t *testing.T) {
context, err := model.NewContext() context, err := model.NewContext()
assert.NoError(err) assert.NoError(err)
err = context.Process(data, nil, nil, nil) err = context.Process(data, nil, nil)
assert.NoError(err) assert.NoError(err)
} }
func TestDetectedLanguage(t *testing.T) {
assert := assert.New(t)
fh, err := os.Open(SamplePath)
assert.NoError(err)
defer fh.Close()
// Decode the WAV file - load the full buffer
dec := wav.NewDecoder(fh)
buf, err := dec.FullPCMBuffer()
assert.NoError(err)
assert.Equal(uint16(1), dec.NumChans)
data := buf.AsFloat32Buffer().Data
model, err := whisper.New(ModelPath)
assert.NoError(err)
assert.NotNil(model)
defer model.Close()
context, err := model.NewContext()
assert.NoError(err)
err = context.Process(data, nil, nil, nil)
assert.NoError(err)
expectedLanguage := "en"
actualLanguage := context.DetectedLanguage()
assert.Equal(expectedLanguage, actualLanguage)
}

View File

@ -16,10 +16,6 @@ type SegmentCallback func(Segment)
// processing. It is called during the Process function // processing. It is called during the Process function
type ProgressCallback func(int) type ProgressCallback func(int)
// EncoderBeginCallback is the callback function for checking if we want to
// continue processing. It is called during the Process function
type EncoderBeginCallback func() bool
// Model is the interface to a whisper model. Create a new model with the // Model is the interface to a whisper model. Create a new model with the
// function whisper.New(string) // function whisper.New(string)
type Model interface { type Model interface {
@ -35,13 +31,12 @@ type Model interface {
Languages() []string Languages() []string
} }
// Context is the speech recognition context. // Context is the speach recognition context.
type Context interface { type Context interface {
SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language. SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language.
SetTranslate(bool) // Set translate flag SetTranslate(bool) // Set translate flag
IsMultilingual() bool // Return true if the model is multilingual. IsMultilingual() bool // Return true if the model is multilingual.
Language() string // Get language Language() string // Get language
DetectedLanguage() string // Get detected language
SetOffset(time.Duration) // Set offset SetOffset(time.Duration) // Set offset
SetDuration(time.Duration) // Set duration SetDuration(time.Duration) // Set duration
@ -63,7 +58,7 @@ type Context interface {
// Process mono audio data and return any errors. // Process mono audio data and return any errors.
// If defined, newly generated segments are passed to the // If defined, newly generated segments are passed to the
// callback function during processing. // callback function during processing.
Process([]float32, EncoderBeginCallback, SegmentCallback, ProgressCallback) error Process([]float32, SegmentCallback, ProgressCallback) error
// After process is called, return segments until the end of the stream // After process is called, return segments until the end of the stream
// is reached, when io.EOF is returned. // is reached, when io.EOF is returned.

View File

@ -9,7 +9,7 @@ import (
// CGO // CGO
/* /*
#cgo LDFLAGS: -lwhisper -lggml -lggml-base -lggml-cpu -lm -lstdc++ -fopenmp #cgo LDFLAGS: -lwhisper -lm -lstdc++ -fopenmp
#cgo darwin LDFLAGS: -framework Accelerate -framework Metal -framework Foundation -framework CoreGraphics #cgo darwin LDFLAGS: -framework Accelerate -framework Metal -framework Foundation -framework CoreGraphics
#include <whisper.h> #include <whisper.h>
#include <stdlib.h> #include <stdlib.h>

View File

@ -52,7 +52,7 @@ public class Example {
In order to build, you need to have the JDK 8 or higher installed. Run the tests with: In order to build, you need to have the JDK 8 or higher installed. Run the tests with:
```bash ```bash
git clone https://github.com/ggml-org/whisper.cpp.git git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp/bindings/java cd whisper.cpp/bindings/java
./gradlew build ./gradlew build

View File

@ -25,43 +25,25 @@ sourceSets {
} }
tasks.register('copyLibwhisperDynlib', Copy) { tasks.register('copyLibwhisperDynlib', Copy) {
from '../../build/src' from '../../build'
include 'libwhisper.dylib' include 'libwhisper.dynlib'
into 'build/generated/resources/main' into 'build/generated/resources/main/darwin'
} }
tasks.register('copyLibwhisperSo', Copy) { tasks.register('copyLibwhisperSo', Copy) {
from '../../build/src' from '../../build'
include 'libwhisper.so' include 'libwhisper.so'
into 'build/generated/resources/main' into 'build/generated/resources/main/linux-x86-64'
} }
tasks.register('copyWhisperDLL', Copy) { tasks.register('copyWhisperDll', Copy) {
from '../../build/bin/Release' from '../../build/Release'
include 'whisper.dll' include 'whisper.dll'
into 'build/generated/resources/main' into 'build/generated/resources/main/windows-x86-64'
}
tasks.register('copyGGML_BASE_DLL', Copy) {
from '../../build/bin/Release'
include 'ggml-base.dll'
into 'build/generated/resources/main'
}
tasks.register('copyGGML_DLL', Copy) {
from '../../build/bin/Release'
include 'ggml.dll'
into 'build/generated/resources/main'
}
tasks.register('copyGGML_CPU_DLL', Copy) {
from '../../build/bin/Release'
include 'ggml-cpu.dll'
into 'build/generated/resources/main'
} }
tasks.register('copyLibs') { tasks.register('copyLibs') {
dependsOn copyLibwhisperDynlib, copyLibwhisperSo, copyWhisperDLL, copyGGML_BASE_DLL, copyGGML_DLL, copyGGML_CPU_DLL dependsOn copyLibwhisperDynlib, copyLibwhisperSo, copyWhisperDll
} }
test { test {
@ -73,12 +55,7 @@ java {
withJavadocJar() withJavadocJar()
} }
sourcesJar() {
dependsOn copyLibs
}
jar { jar {
dependsOn copyLibs
exclude '**/whisper_java.exp', '**/whisper_java.lib' exclude '**/whisper_java.exp', '**/whisper_java.lib'
} }
@ -90,9 +67,6 @@ tasks.withType(Test) {
useJUnitPlatform() useJUnitPlatform()
} }
test.dependsOn copyLibs
processResources.dependsOn copyLibs
dependencies { dependencies {
implementation "net.java.dev.jna:jna:5.13.0" implementation "net.java.dev.jna:jna:5.13.0"
testImplementation "org.junit.jupiter:junit-jupiter:5.9.2" testImplementation "org.junit.jupiter:junit-jupiter:5.9.2"

0
bindings/java/gradlew vendored Executable file → Normal file
View File

View File

@ -1,24 +0,0 @@
package io.github.ggerganov.whispercpp;
/**
* Presets for alignment heads in DTW token timestamps
*/
public class WhisperConstants {
// Alignment heads presets
public static final int WHISPER_AHEADS_NONE = 0;
public static final int WHISPER_AHEADS_TINY_EN = 1;
public static final int WHISPER_AHEADS_TINY = 2;
public static final int WHISPER_AHEADS_BASE_EN = 3;
public static final int WHISPER_AHEADS_BASE = 4;
public static final int WHISPER_AHEADS_SMALL_EN = 5;
public static final int WHISPER_AHEADS_SMALL = 6;
public static final int WHISPER_AHEADS_MEDIUM_EN = 7;
public static final int WHISPER_AHEADS_MEDIUM = 8;
public static final int WHISPER_AHEADS_LARGE_V1 = 9;
public static final int WHISPER_AHEADS_LARGE_V2 = 10;
public static final int WHISPER_AHEADS_LARGE_V3 = 11;
public static final int WHISPER_AHEADS_LARGE_V3_TURBO = 12;
public static final int WHISPER_AHEADS_CUSTOM = 13;
public static final int WHISPER_AHEADS_N_TOP_MOST = 14;
public static final int WHISPER_AHEADS_COUNT = 15;
}

View File

@ -1,9 +1,7 @@
package io.github.ggerganov.whispercpp; package io.github.ggerganov.whispercpp;
import com.sun.jna.NativeLong;
import com.sun.jna.Structure; import com.sun.jna.Structure;
import com.sun.jna.ptr.PointerByReference; import com.sun.jna.ptr.PointerByReference;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.ggml.GgmlType; import io.github.ggerganov.whispercpp.ggml.GgmlType;
import io.github.ggerganov.whispercpp.WhisperModel; import io.github.ggerganov.whispercpp.WhisperModel;
import io.github.ggerganov.whispercpp.params.WhisperContextParams; import io.github.ggerganov.whispercpp.params.WhisperContextParams;
@ -11,26 +9,33 @@ import io.github.ggerganov.whispercpp.params.WhisperContextParams;
import java.util.List; import java.util.List;
public class WhisperContext extends Structure { public class WhisperContext extends Structure {
public NativeLong t_load_us; int t_load_us = 0;
public NativeLong t_start_us; int t_start_us = 0;
/** weight type (FP32 / FP16 / QX) */ /** weight type (FP32 / FP16 / QX) */
public GgmlType wtype = GgmlType.GGML_TYPE_F16; GgmlType wtype = GgmlType.GGML_TYPE_F16;
/** intermediate type (FP32 or FP16) */ /** intermediate type (FP32 or FP16) */
public GgmlType itype = GgmlType.GGML_TYPE_F16; GgmlType itype = GgmlType.GGML_TYPE_F16;
public WhisperContextParams.ByValue params; // WhisperModel model;
public PointerByReference model;
public Pointer model; // whisper_vocab vocab;
public Pointer vocab; // whisper_state * state = nullptr;
public Pointer state; public PointerByReference vocab;
public PointerByReference state;
/** populated by whisper_init_from_file_with_params() */ /** populated by whisper_init_from_file_with_params() */
public Pointer path_model; String path_model;
WhisperContextParams params;
@Override // public static class ByReference extends WhisperContext implements Structure.ByReference {
protected List<String> getFieldOrder() { // }
return List.of("t_load_us", "t_start_us", "wtype", "itype", //
"params", "model", "vocab", "state", "path_model"); // public static class ByValue extends WhisperContext implements Structure.ByValue {
} // }
//
// @Override
// protected List<String> getFieldOrder() {
// return List.of("t_load_us", "t_start_us", "wtype", "itype", "model", "vocab", "state", "path_model");
// }
} }

View File

@ -43,11 +43,11 @@ public class WhisperCpp implements AutoCloseable {
* @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en") * @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")
* @param params - params to use when initialising the context * @param params - params to use when initialising the context
*/ */
public void initContext(String modelPath, WhisperContextParams.ByValue params) throws FileNotFoundException { public void initContext(String modelPath, WhisperContextParams params) throws FileNotFoundException {
initContextImpl(modelPath, params); initContextImpl(modelPath, params);
} }
private void initContextImpl(String modelPath, WhisperContextParams.ByValue params) throws FileNotFoundException { private void initContextImpl(String modelPath, WhisperContextParams params) throws FileNotFoundException {
if (ctx != null) { if (ctx != null) {
lib.whisper_free(ctx); lib.whisper_free(ctx);
} }
@ -69,13 +69,15 @@ public class WhisperCpp implements AutoCloseable {
/** /**
* Provides default params which can be used with `whisper_init_from_file_with_params()` etc. * Provides default params which can be used with `whisper_init_from_file_with_params()` etc.
* Returns a ByValue instance to ensure proper parameter passing to native code. * Because this function allocates memory for the params, the caller must call either:
* - call `whisper_free_context_params()`
* - `Native.free(Pointer.nativeValue(pointer));`
*/ */
public WhisperContextParams.ByValue getContextDefaultParams() { public WhisperContextParams getContextDefaultParams() {
WhisperContextParams.ByValue valueParams = new WhisperContextParams.ByValue( paramsPointer = lib.whisper_context_default_params_by_ref();
lib.whisper_context_default_params_by_ref()); WhisperContextParams params = new WhisperContextParams(paramsPointer);
valueParams.read(); params.read();
return valueParams; return params;
} }
/** /**
@ -86,7 +88,7 @@ public class WhisperCpp implements AutoCloseable {
* *
* @param strategy - GREEDY * @param strategy - GREEDY
*/ */
public WhisperFullParams.ByValue getFullDefaultParams(WhisperSamplingStrategy strategy) { public WhisperFullParams getFullDefaultParams(WhisperSamplingStrategy strategy) {
Pointer pointer; Pointer pointer;
// whisper_full_default_params_by_ref allocates memory which we need to delete, so only create max 1 pointer for each strategy. // whisper_full_default_params_by_ref allocates memory which we need to delete, so only create max 1 pointer for each strategy.
@ -102,7 +104,7 @@ public class WhisperCpp implements AutoCloseable {
pointer = beamParamsPointer; pointer = beamParamsPointer;
} }
WhisperFullParams.ByValue params = new WhisperFullParams.ByValue(pointer); WhisperFullParams params = new WhisperFullParams(pointer);
params.read(); params.read();
return params; return params;
} }
@ -136,21 +138,15 @@ public class WhisperCpp implements AutoCloseable {
} }
/** /**
* Run the entire model: PCM -&gt; log mel spectrogram -&gt; encoder -&gt; decoder -&gt; text. * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.
* Not thread safe for same context * Not thread safe for same context
* Uses the specified decoding strategy to obtain the text. * Uses the specified decoding strategy to obtain the text.
*/ */
public String fullTranscribe(WhisperFullParams.ByValue whisperParams, float[] audioData) throws IOException { public String fullTranscribe(WhisperFullParams whisperParams, float[] audioData) throws IOException {
if (ctx == null) { if (ctx == null) {
throw new IllegalStateException("Model not initialised"); throw new IllegalStateException("Model not initialised");
} }
/*
WhisperFullParams.ByValue valueParams = new WhisperFullParams.ByValue(
lib.whisper_full_default_params_by_ref(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH.ordinal()));
valueParams.read();
*/
if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) { if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) {
throw new IOException("Failed to process audio"); throw new IOException("Failed to process audio");
} }
@ -167,17 +163,12 @@ public class WhisperCpp implements AutoCloseable {
return str.toString().trim(); return str.toString().trim();
} }
public List<WhisperSegment> fullTranscribeWithTime(WhisperFullParams whisperParams, float[] audioData) throws IOException { public List<WhisperSegment> fullTranscribeWithTime(WhisperFullParams whisperParams, float[] audioData) throws IOException {
if (ctx == null) { if (ctx == null) {
throw new IllegalStateException("Model not initialised"); throw new IllegalStateException("Model not initialised");
} }
WhisperFullParams.ByValue valueParams = new WhisperFullParams.ByValue( if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) {
lib.whisper_full_default_params_by_ref(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH.ordinal()));
valueParams.read();
if (lib.whisper_full(ctx, valueParams, audioData, audioData.length) != 0) {
throw new IOException("Failed to process audio"); throw new IOException("Failed to process audio");
} }

View File

@ -9,7 +9,6 @@ import io.github.ggerganov.whispercpp.params.WhisperContextParams;
import io.github.ggerganov.whispercpp.params.WhisperFullParams; import io.github.ggerganov.whispercpp.params.WhisperFullParams;
public interface WhisperCppJnaLibrary extends Library { public interface WhisperCppJnaLibrary extends Library {
WhisperCppJnaLibrary instance = Native.load("whisper", WhisperCppJnaLibrary.class); WhisperCppJnaLibrary instance = Native.load("whisper", WhisperCppJnaLibrary.class);
String whisper_print_system_info(); String whisper_print_system_info();
@ -39,7 +38,7 @@ public interface WhisperCppJnaLibrary extends Library {
* @param params Pointer to whisper_context_params * @param params Pointer to whisper_context_params
* @return Whisper context on success, null on failure * @return Whisper context on success, null on failure
*/ */
Pointer whisper_init_from_file_with_params(String path_model, WhisperContextParams.ByValue params); Pointer whisper_init_from_file_with_params(String path_model, WhisperContextParams params);
/** /**
* Allocate (almost) all memory needed for the model by loading from a buffer. * Allocate (almost) all memory needed for the model by loading from a buffer.
@ -181,12 +180,12 @@ public interface WhisperCppJnaLibrary extends Library {
/** /**
* @return the id of the specified language, returns -1 if not found. * @return the id of the specified language, returns -1 if not found.
* Examples: * Examples:
* "de" -&gt; 2 * "de" -> 2
* "german" -&gt; 2 * "german" -> 2
*/ */
int whisper_lang_id(String lang); int whisper_lang_id(String lang);
/** @return the short string of the specified language id (e.g. 2 -&gt; "de"), returns nullptr if not found */ /** @return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found */
String whisper_lang_str(int id); String whisper_lang_str(int id);
/** /**
@ -269,21 +268,20 @@ public interface WhisperCppJnaLibrary extends Library {
void whisper_free_params(Pointer params); void whisper_free_params(Pointer params);
/** /**
* Run the entire model: PCM -&gt; log mel spectrogram -&gt; encoder -&gt; decoder -&gt; text * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
* Not thread safe for same context * Not thread safe for same context
* Uses the specified decoding strategy to obtain the text. * Uses the specified decoding strategy to obtain the text.
*/ */
int whisper_full(Pointer ctx, WhisperFullParams.ByValue params, final float[] samples, int n_samples); int whisper_full(Pointer ctx, WhisperFullParams params, final float[] samples, int n_samples);
public int whisper_full_with_state(Pointer ctx, Pointer state, WhisperFullParams.ByValue params, float[] samples, int n_samples); int whisper_full_with_state(Pointer ctx, Pointer state, WhisperFullParams params, final float[] samples, int n_samples);
//int whisper_full_with_state(Pointer ctx, Pointer state, WhisperFullParams params, final float[] samples, int n_samples);
// Split the input audio in chunks and process each chunk separately using whisper_full_with_state() // Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
// Result is stored in the default state of the context // Result is stored in the default state of the context
// Not thread safe if executed in parallel on the same context. // Not thread safe if executed in parallel on the same context.
// It seems this approach can offer some speedup in some cases. // It seems this approach can offer some speedup in some cases.
// However, the transcription accuracy can be worse at the beginning and end of each chunk. // However, the transcription accuracy can be worse at the beginning and end of each chunk.
int whisper_full_parallel(Pointer ctx, WhisperFullParams.ByValue params, final float[] samples, int n_samples, int n_processors); int whisper_full_parallel(Pointer ctx, WhisperFullParams params, final float[] samples, int n_samples, int n_processors);
/** /**
* Number of generated text segments. * Number of generated text segments.

View File

@ -1,17 +0,0 @@
package io.github.ggerganov.whispercpp.callbacks;
import com.sun.jna.Callback;
/**
* Callback for aborting GGML computation
* Maps to the C typedef: bool (*ggml_abort_callback)(void * data)
*/
public interface GgmlAbortCallback extends Callback {
/**
* Return true to abort the computation, false to continue
*
* @param data User data passed to the callback
* @return true to abort, false to continue
*/
boolean invoke(com.sun.jna.Pointer data);
}

View File

@ -1,30 +0,0 @@
package io.github.ggerganov.whispercpp.params;
import com.sun.jna.*;
import java.util.Arrays;
import java.util.List;
public class WhisperAhead extends Structure {
public int n_text_layer;
public int n_head;
public WhisperAhead() {
super();
}
public WhisperAhead(int textLayer, int head) {
super();
this.n_text_layer = textLayer;
this.n_head = head;
}
@Override
protected List<String> getFieldOrder() {
return Arrays.asList("n_text_layer", "n_head");
}
public static class ByReference extends WhisperAhead implements Structure.ByReference {}
public static class ByValue extends WhisperAhead implements Structure.ByValue {}
}

View File

@ -1,41 +0,0 @@
package io.github.ggerganov.whispercpp.params;
import com.sun.jna.*;
import java.util.Arrays;
import java.util.List;
public class WhisperAheads extends Structure {
public NativeLong n_heads;
public Pointer heads;
public WhisperAheads() {
super();
}
/**
* Create alignment heads from an array of WhisperAhead objects
*/
public void setHeads(WhisperAhead[] aheadsArray) {
this.n_heads = new NativeLong(aheadsArray.length);
int structSize = aheadsArray[0].size();
Memory mem = new Memory(structSize * aheadsArray.length);
for (int i = 0; i < aheadsArray.length; i++) {
aheadsArray[i].write();
byte[] buffer = aheadsArray[i].getPointer().getByteArray(0, structSize);
mem.write(i * structSize, buffer, 0, buffer.length);
}
this.heads = mem;
}
@Override
protected List<String> getFieldOrder() {
return Arrays.asList("n_heads", "heads");
}
public static class ByReference extends WhisperAheads implements Structure.ByReference {}
public static class ByValue extends WhisperAheads implements Structure.ByValue {}
}

View File

@ -1,5 +1,7 @@
package io.github.ggerganov.whispercpp.params; package io.github.ggerganov.whispercpp.params;
import com.sun.jna.*; import com.sun.jna.*;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
@ -9,73 +11,21 @@ import java.util.List;
* whisper_context_default_params() * whisper_context_default_params()
*/ */
public class WhisperContextParams extends Structure { public class WhisperContextParams extends Structure {
public WhisperContextParams(Pointer p) { public WhisperContextParams(Pointer p) {
super(p); super(p);
} }
public WhisperContextParams() { /** Use GPU for inference Number (default = true) */
super();
}
/** Use GPU for inference (default = true) */
public CBool use_gpu; public CBool use_gpu;
/** Use flash attention (default = false) */ /** Use GPU for inference Number (default = true) */
public CBool flash_attn;
/** CUDA device to use (default = 0) */
public int gpu_device;
/** [EXPERIMENTAL] Enable token-level timestamps with DTW (default = false) */
public CBool dtw_token_timestamps;
/** [EXPERIMENTAL] Alignment heads preset for DTW */
public int dtw_aheads_preset;
/** Number of top layers to use for DTW when using WHISPER_AHEADS_N_TOP_MOST preset */
public int dtw_n_top;
public WhisperAheads.ByValue dtw_aheads;
/** DTW memory size (internal use) */
public NativeLong dtw_mem_size;
/** Use GPU for inference */
public void useGpu(boolean enable) { public void useGpu(boolean enable) {
use_gpu = enable ? CBool.TRUE : CBool.FALSE; use_gpu = enable ? CBool.TRUE : CBool.FALSE;
} }
/** Use flash attention */
public void useFlashAttn(boolean enable) {
flash_attn = enable ? CBool.TRUE : CBool.FALSE;
}
/** Enable DTW token-level timestamps */
public void enableDtwTokenTimestamps(boolean enable) {
dtw_token_timestamps = enable ? CBool.TRUE : CBool.FALSE;
}
/** Set DTW alignment heads preset */
public void setDtwAheadsPreset(int preset) {
dtw_aheads_preset = preset;
}
@Override @Override
protected List<String> getFieldOrder() { protected List<String> getFieldOrder() {
return Arrays.asList( return Arrays.asList("use_gpu");
"use_gpu",
"flash_attn",
"gpu_device",
"dtw_token_timestamps",
"dtw_aheads_preset",
"dtw_n_top",
"dtw_aheads",
"dtw_mem_size"
);
}
public static class ByValue extends WhisperContextParams implements Structure.ByValue {
public ByValue() { super(); }
public ByValue(Pointer p) { super(p); }
} }
} }

View File

@ -5,7 +5,6 @@ import io.github.ggerganov.whispercpp.callbacks.WhisperEncoderBeginCallback;
import io.github.ggerganov.whispercpp.callbacks.WhisperLogitsFilterCallback; import io.github.ggerganov.whispercpp.callbacks.WhisperLogitsFilterCallback;
import io.github.ggerganov.whispercpp.callbacks.WhisperNewSegmentCallback; import io.github.ggerganov.whispercpp.callbacks.WhisperNewSegmentCallback;
import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback; import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback;
import io.github.ggerganov.whispercpp.callbacks.GgmlAbortCallback;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
@ -17,12 +16,10 @@ import java.util.List;
*/ */
public class WhisperFullParams extends Structure { public class WhisperFullParams extends Structure {
public WhisperFullParams() {
super();
}
public WhisperFullParams(Pointer p) { public WhisperFullParams(Pointer p) {
super(p); super(p);
// super(p, ALIGN_MSVC);
// super(p, ALIGN_GNUC);
} }
/** Sampling strategy for whisper_full() function. */ /** Sampling strategy for whisper_full() function. */
@ -72,10 +69,10 @@ public class WhisperFullParams extends Structure {
single_segment = single ? CBool.TRUE : CBool.FALSE; single_segment = single ? CBool.TRUE : CBool.FALSE;
} }
/** Flag to print special tokens (e.g., &lt;SOT&gt;, &lt;EOT&gt;, &lt;BEG&gt;, etc.). (default = false) */ /** Flag to print special tokens (e.g., &lt;SOT>, &lt;EOT>, &lt;BEG>, etc.). (default = false) */
public CBool print_special; public CBool print_special;
/** Flag to print special tokens (e.g., &lt;SOT&gt;, &lt;EOT&gt;, &lt;BEG&gt;, etc.). (default = false) */ /** Flag to print special tokens (e.g., &lt;SOT>, &lt;EOT>, &lt;BEG>, etc.). (default = false) */
public void printSpecial(boolean enable) { public void printSpecial(boolean enable) {
print_special = enable ? CBool.TRUE : CBool.FALSE; print_special = enable ? CBool.TRUE : CBool.FALSE;
} }
@ -132,14 +129,6 @@ public class WhisperFullParams extends Structure {
/** Maximum tokens per segment (0, default = no limit) */ /** Maximum tokens per segment (0, default = no limit) */
public int max_tokens; public int max_tokens;
/** [EXPERIMENTAL] Enable debug mode for extra info */
public CBool debug_mode;
/** Enable debug mode */
public void enableDebugMode(boolean enable) {
debug_mode = enable ? CBool.TRUE : CBool.FALSE;
}
/** Overwrite the audio context size (0 = use default). */ /** Overwrite the audio context size (0 = use default). */
public int audio_ctx; public int audio_ctx;
@ -285,16 +274,6 @@ public class WhisperFullParams extends Structure {
*/ */
public Pointer encoder_begin_callback_user_data; public Pointer encoder_begin_callback_user_data;
/** Callback used to abort GGML computation */
public Pointer abort_callback;
/** User data for the abort_callback */
public Pointer abort_callback_user_data;
public void setAbortCallback(GgmlAbortCallback callback) {
abort_callback = CallbackReference.getFunctionPointer(callback);
}
/** /**
* Callback by each decoder to filter obtained logits. * Callback by each decoder to filter obtained logits.
* WhisperLogitsFilterCallback * WhisperLogitsFilterCallback
@ -331,28 +310,17 @@ public class WhisperFullParams extends Structure {
@Override @Override
protected List<String> getFieldOrder() { protected List<String> getFieldOrder() {
return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate",
"offset_ms", "duration_ms", "translate", "no_context", "no_context", "single_segment", "no_timestamps",
"no_timestamps", "single_segment", "print_special", "print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
"print_progress", "print_realtime", "print_timestamps", "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "audio_ctx",
"token_timestamps", "thold_pt", "thold_ptsum", "max_len", "tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
"split_on_word", "max_tokens", "debug_mode", "audio_ctx", "suppress_blank", "suppress_nst", "temperature", "max_initial_ts", "length_penalty",
"tdrz_enable", "suppress_regex", "initial_prompt", "temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
"prompt_tokens", "prompt_n_tokens", "language", "detect_language", "new_segment_callback", "new_segment_callback_user_data",
"suppress_blank", "suppress_nst", "temperature",
"max_initial_ts", "length_penalty", "temperature_inc",
"entropy_thold", "logprob_thold", "no_speech_thold", "greedy",
"beam_search", "new_segment_callback", "new_segment_callback_user_data",
"progress_callback", "progress_callback_user_data", "progress_callback", "progress_callback_user_data",
"encoder_begin_callback", "encoder_begin_callback_user_data", "encoder_begin_callback", "encoder_begin_callback_user_data",
"abort_callback", "abort_callback_user_data",
"logits_filter_callback", "logits_filter_callback_user_data", "logits_filter_callback", "logits_filter_callback_user_data",
"grammar_rules", "n_grammar_rules", "i_start_rule", "grammar_penalty"); "grammar_rules", "n_grammar_rules", "i_start_rule", "grammar_penalty");
} }
public static class ByValue extends WhisperFullParams implements Structure.ByValue {
public ByValue() { super(); }
public ByValue(Pointer p) { super(p); }
}
} }

View File

@ -76,7 +76,7 @@ class WhisperCppTest {
float[] floats = new float[b.length / 2]; float[] floats = new float[b.length / 2];
//WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY); //WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
WhisperFullParams.ByValue params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH); WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress)); params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));
params.print_progress = CBool.FALSE; params.print_progress = CBool.FALSE;
//params.initial_prompt = "and so my fellow Americans um, like"; //params.initial_prompt = "and so my fellow Americans um, like";

View File

@ -33,9 +33,6 @@ mkdir build-em && cd build-em
emcmake cmake .. && make -j emcmake cmake .. && make -j
# run test # run test
node ../tests/test-whisper.js
# For Node.js versions prior to v16.4.0, experimental features need to be enabled:
node --experimental-wasm-threads --experimental-wasm-simd ../tests/test-whisper.js node --experimental-wasm-threads --experimental-wasm-simd ../tests/test-whisper.js
# publish npm package # publish npm package

View File

@ -1,6 +1,6 @@
{ {
"name": "whisper.cpp", "name": "whisper.cpp",
"version": "1.7.5", "version": "1.7.4",
"description": "Whisper speech recognition", "description": "Whisper speech recognition",
"main": "whisper.js", "main": "whisper.js",
"scripts": { "scripts": {

View File

@ -1,6 +1,3 @@
LICENSE LICENSE
pkg/ pkg/
lib/whisper.* lib/whisper.*
ext/sources/*
!ext/sources/CMakeGraphVizOptions.cmake
ext/mkmf.log

View File

@ -16,18 +16,6 @@ If bundler is not being used to manage dependencies, install the gem by executin
$ gem install whispercpp $ gem install whispercpp
You can pass build options for whisper.cpp, for instance:
$ bundle config build.whispercpp --enable-ggml-cuda
or,
$ gem install whispercpp -- --enable-ggml-cuda
See whisper.cpp's [README](https://github.com/ggml-org/whisper.cpp/blob/master/README.md) for available options. You need convert options present the README to Ruby-style options.
For boolean options like `GGML_CUDA`, the README says `-DGGML_CUDA=1`. You need strip `-D`, prepend `--enable-` for `1` or `ON` (`--disable-` for `0` or `OFF`) and make it kebab-case: `--enable-ggml-cuda`.
For options which require arguments like `CMAKE_CUDA_ARCHITECTURES`, the README says `-DCMAKE_CUDA_ARCHITECTURES="86"`. You need strip `-D`, prepend `--`, make it kebab-case, append `=` and append argument: `--cmake-cuda-architectures="86"`.
Usage Usage
----- -----
@ -240,7 +228,7 @@ The second argument `samples` may be an array, an object with `length` and `each
Development Development
----------- -----------
% git clone https://github.com/ggml-org/whisper.cpp.git % git clone https://github.com/ggerganov/whisper.cpp.git
% cd whisper.cpp/bindings/ruby % cd whisper.cpp/bindings/ruby
% rake test % rake test
@ -253,5 +241,5 @@ License
The same to [whisper.cpp][]. The same to [whisper.cpp][].
[whisper.cpp]: https://github.com/ggml-org/whisper.cpp [whisper.cpp]: https://github.com/ggerganov/whisper.cpp
[models]: https://github.com/ggml-org/whisper.cpp/tree/master/models [models]: https://github.com/ggerganov/whisper.cpp/tree/master/models

View File

@ -3,15 +3,11 @@ require "bundler/gem_tasks"
require "rake/testtask" require "rake/testtask"
require_relative "extsources" require_relative "extsources"
SOURCES_DIR = "ext/sources"
SOURCES = FileList[] SOURCES = FileList[]
EXTSOURCES.each do |src| EXTSOURCES.each do |src|
basename = src.pathmap("%f") basename = src.pathmap("%f")
dest = basename == "LICENSE" ? basename dest = basename == "LICENSE" ? basename : src.pathmap("%{../..,ext}p")
: src.pathmap("%{\\.\\./\\.\\.,#{SOURCES_DIR}}p")
.pathmap("%{\\.\\./javascript,#{SOURCES_DIR}/bindings/javascript}p")
dir = dest.pathmap("%d") dir = dest.pathmap("%d")
file src file src
directory dir directory dir
@ -22,6 +18,7 @@ EXTSOURCES.each do |src|
end end
CLEAN.include SOURCES CLEAN.include SOURCES
CLEAN.include FileList["ext/**/*.o", "ext/**/*.metal", "ext/**/*.tmp", "ext/whisper.{so,bundle,dll}"]
SRC = FileList["ext/*.{c,cpp,h}"] SRC = FileList["ext/*.{c,cpp,h}"]
@ -39,20 +36,6 @@ file "ext/Makefile" => SRC + ["ext/extconf.rb"] + SOURCES do |t|
ruby "extconf.rb" ruby "extconf.rb"
end end
end end
if File.exist? "ext/Makefile"
task :make_clean do
cd "ext" do
sh "make", "clean"
end
end
task clean: :make_clean
task :make_distclean do
cd "ext" do
sh "make", "distclean"
end
end
task clobber: :make_distclean
end
file SO_FILE => "ext/Makefile" do |t| file SO_FILE => "ext/Makefile" do |t|
chdir "ext" do chdir "ext" do

9
bindings/ruby/ext/cpu.mk Normal file
View File

@ -0,0 +1,9 @@
ggml/src/ggml-cpu/ggml-cpu-cpp.o: \
ggml/src/ggml-cpu/ggml-cpu.cpp \
ggml/include/ggml-backend.h \
ggml/include/ggml.h \
ggml/include/ggml-alloc.h \
ggml/src/ggml-backend-impl.h \
ggml/include/ggml-cpu.h \
ggml/src/ggml-impl.h
$(CXX) $(CXXFLAGS) -c $< -o $@

View File

@ -1,61 +0,0 @@
require "tsort"
class Dependencies
def initialize(cmake, options)
@cmake = cmake
@options = options
generate_dot
@libs = parse_dot
end
def to_s
@libs.join(" ")
end
private
def dot_path
File.join(__dir__, "build", "whisper.cpp.dot")
end
def generate_dot
system @cmake, "-S", "sources", "-B", "build", "--graphviz", dot_path, "-D", "BUILD_SHARED_LIBS=OFF", @options.to_s, exception: true
end
def parse_dot
static_lib_shape = nil
nodes = {}
depends = Hash.new {|h, k| h[k] = []}
class << depends
include TSort
alias tsort_each_node each_key
def tsort_each_child(node, &block)
fetch(node, []).each(&block)
end
end
File.open(dot_path).each_line do |line|
case line
when /\[\s*label\s*=\s*"Static Library"\s*,\s*shape\s*=\s*(?<shape>\w+)\s*\]/
static_lib_shape = $~[:shape]
when /\A\s*"(?<node>\w+)"\s*\[\s*label\s*=\s*"(?<label>\S+)"\s*,\s*shape\s*=\s*(?<shape>\w+)\s*\]\s*;\s*\z/
node = $~[:node]
label = $~[:label]
shape = $~[:shape]
nodes[node] = [label, shape]
when /\A\s*"(?<depender>\w+)"\s*->\s*"(?<dependee>\w+)"/
depender = $~[:depender]
dependee = $~[:dependee]
depends[depender] ||= []
depends[depender] << dependee
end
end
depends.tsort.filter_map {|node|
label, shape = nodes[node]
shape == static_lib_shape ? label : nil
}.collect {|lib| "lib#{lib}.a"}
.reverse
end
end

View File

@ -1,22 +1,208 @@
require "mkmf" require 'mkmf'
require_relative "options"
require_relative "dependencies"
cmake = find_executable("cmake") || abort # need to use c++ compiler flags
options = Options.new $CXXFLAGS << ' -std=c++17'
have_library("gomp") rescue nil
libs = Dependencies.new(cmake, options)
$INCFLAGS << " -Isources/include -Isources/ggml/include -Isources/examples" $LDFLAGS << ' -lstdc++'
$LOCAL_LIBS << " #{libs}"
$cleanfiles << " build #{libs}"
create_makefile "whisper" do |conf| # Set to true when building binary gems
conf << <<~EOF if enable_config('static-stdlib', false)
$(TARGET_SO): #{libs} $LDFLAGS << ' -static-libgcc -static-libstdc++'
#{libs}: cmake-targets end
cmake-targets:
#{"\t"}#{cmake} -S sources -B build -D BUILD_SHARED_LIBS=OFF -D CMAKE_ARCHIVE_OUTPUT_DIRECTORY=#{__dir__} -D CMAKE_POSITION_INDEPENDENT_CODE=ON #{options} if enable_config('march-tune-native', false)
#{"\t"}#{cmake} --build build --config Release --target common whisper $CFLAGS << ' -march=native -mtune=native'
EOF $CXXFLAGS << ' -march=native -mtune=native'
end
if ENV['WHISPER_METAL']
$GGML_METAL ||= true
$DEPRECATE_WARNING ||= true
end
$UNAME_S = `uname -s`.chomp
$UNAME_P = `uname -p`.chomp
$UNAME_M = `uname -m`.chomp
if $UNAME_S == 'Darwin'
unless ENV['GGML_NO_METAL']
$GGML_METAL ||= true
end
$GGML_NO_OPENMP ||= true
end
if $GGML_METAL
$GGML_METAL_EMBED_LIBRARY = true
end
$MK_CPPFLAGS = '-Iggml/include -Iggml/src -Iggml/src/ggml-cpu -Iinclude -Isrc -Iexamples -DGGML_USE_CPU'
$MK_CFLAGS = '-std=c11 -fPIC'
$MK_CXXFLAGS = '-std=c++17 -fPIC'
$MK_NVCCFLAGS = '-std=c++17'
$MK_LDFLAGS = ''
$OBJ_GGML = []
$OBJ_WHISPER = []
$OBJ_COMMON = []
$OBJ_SDL = []
$MK_CPPFLAGS << ' -D_XOPEN_SOURCE=600'
if $UNAME_S == 'Linux'
$MK_CPPFLAGS << ' -D_GNU_SOURCE'
end
if $UNAME_S == 'Darwin'
$MK_CPPFLAGS << ' -D_DARWIN_C_SOURCE'
end
if ENV['WHISPER_DEBUG']
$MK_CFLAGS << ' -O0 -g'
$MK_CXXFLAGS << ' -O0 -g'
$MK_LDFLAGS << ' -g'
$MK_NVCCFLAGS << ' -O0 -g'
else
$MK_CPPFLAGS << ' -DNDEBUG'
$MK_CFLAGS << ' -O3'
$MK_CXXFLAGS << ' -O3'
$MK_NVCCFLAGS << ' -O3'
end
$WARN_FLAGS =
' -Wall' <<
' -Wextra' <<
' -Wpedantic' <<
' -Wcast-qual' <<
' -Wno-unused-function'
$MK_CFLAGS <<
$WARN_FLAGS <<
' -Wshadow' <<
' -Wstrict-prototypes' <<
' -Wpointer-arith' <<
' -Wmissing-prototypes' <<
' -Werror=implicit-int' <<
' -Werror=implicit-function-declaration'
$MK_CXXFLAGS <<
$WARN_FLAGS <<
' -Wmissing-declarations' <<
' -Wmissing-noreturn'
unless `#{cc_command} #{$LDFLAGS} -Wl,-v 2>&1`.chomp.include? 'dyld-1015.7'
$MK_CPPFLAGS << ' -DHAVE_BUGGY_APPLE_LINKER'
end
if %w[Linux Darwin FreeBSD NetBSD OpenBSD Haiku].include? $UNAME_S
$MK_CFLAGS << ' -pthread'
$MK_CXXFLAGS << ' -pthread'
end
unless $_WIN32
$DSO_EXT = '.so'
else
$DSO_EXT = '.dll'
end
unless ENV['RISCV']
if %w[x86_64 i686 amd64].include? $UNAME_M
$HOST_CXXFLAGS ||= ''
$MK_CFLAGS << ' -march=native -mtune=native'
$HOST_CXXFLAGS << ' -march=native -mtune=native'
end
else
$MK_CFLAGS << ' -march=rv64gcv -mabi=lp64d'
$MK_CXXFLAGS << ' -march=rv64gcv -mabi=lp64d'
end
unless ENV['GGML_NO_ACCELERATE']
if $UNAME_S == 'Darwin'
$MK_CPPFLAGS << ' -DGGML_USE_ACCELERATE -DGGML_USE_BLAS -DGGML_BLAS_USE_ACCELERATE'
$MK_CPPFLAGS << ' -DACCELERATE_NEW_LAPACK'
$MK_CPPFLAGS << ' -DACCELERATE_LAPACK_ILP64'
$MK_LDFLAGS << ' -framework Accelerate'
$OBJ_GGML << 'ggml/src/ggml-blas/ggml-blas.o'
end
end
if ENV['GGML_OPENBLAS']
$MK_CPPFLAGS << " -DGGML_USE_BLAS #{`pkg-config --cflags-only-I openblas`.chomp}"
$MK_CFLAGS << " #{`pkg-config --cflags-only-other openblas)`.chomp}"
$MK_LDFLAGS << " #{`pkg-config --libs openblas`}"
$OBJ_GGML << 'ggml/src/ggml-blas/ggml-blas.o'
end
if ENV['GGML_OPENBLAS64']
$MK_CPPFLAGS << " -DGGML_USE_BLAS #{`pkg-config --cflags-only-I openblas64`.chomp}"
$MK_CFLAGS << " #{`pkg-config --cflags-only-other openblas64)`.chomp}"
$MK_LDFLAGS << " #{`pkg-config --libs openblas64`}"
$OBJ_GGML << 'ggml/src/ggml-blas/ggml-blas.o'
end
if $GGML_METAL
$MK_CPPFLAGS << ' -DGGML_USE_METAL'
$MK_LDFLAGS << ' -framework Foundation -framework Metal -framework MetalKit'
$OBJ_GGML << 'ggml/src/ggml-metal/ggml-metal.o'
if ENV['GGML_METAL_NDEBUG']
$MK_CPPFLAGS << ' -DGGML_METAL_NDEBUG'
end
if $GGML_METAL_EMBED_LIBRARY
$MK_CPPFLAGS << ' -DGGML_METAL_EMBED_LIBRARY'
$OBJ_GGML << 'ggml/src/ggml-metal/ggml-metal-embed.o'
end
end
$OBJ_GGML <<
'ggml/src/ggml.o' <<
'ggml/src/ggml-alloc.o' <<
'ggml/src/ggml-backend.o' <<
'ggml/src/ggml-backend-reg.o' <<
'ggml/src/ggml-opt.o' <<
'ggml/src/ggml-quants.o' <<
'ggml/src/ggml-threading.o' <<
'ggml/src/ggml-cpu/ggml-cpu.o' <<
'ggml/src/ggml-cpu/ggml-cpu-cpp.o' <<
'ggml/src/ggml-cpu/ggml-cpu-aarch64.o' <<
'ggml/src/ggml-cpu/ggml-cpu-hbm.o' <<
'ggml/src/ggml-cpu/ggml-cpu-quants.o' <<
'ggml/src/ggml-cpu/ggml-cpu-traits.o'
$OBJ_WHISPER <<
'src/whisper.o' <<
'examples/common.o' <<
'examples/common-whisper.o'
$objs = $OBJ_GGML + $OBJ_WHISPER + $OBJ_COMMON + $OBJ_SDL
$objs <<
"ruby_whisper.o" <<
"ruby_whisper_context.o" <<
"ruby_whisper_transcribe.o" <<
"ruby_whisper_params.o" <<
"ruby_whisper_error.o" <<
"ruby_whisper_segment.o" <<
"ruby_whisper_model.o"
$CPPFLAGS = "#{$MK_CPPFLAGS} #{$CPPFLAGS}"
$CFLAGS = "#{$CPPFLAGS} #{$MK_CFLAGS} #{$GF_CFLAGS} #{$CFLAGS}"
$BASE_CXXFLAGS = "#{$MK_CXXFLAGS} #{$CXXFLAGS}"
$CXXFLAGS = "#{$BASE_CXXFLAGS} #{$HOST_CXXFLAGS} #{$GF_CXXFLAGS} #{$CPPFLAGS}"
$NVCCFLAGS = "#{$MK_NVCCFLAGS} #{$NVCCFLAGS}"
$LDFLAGS = "#{$MK_LDFLAGS} #{$LDFLAGS}"
create_makefile('whisper')
File.open 'Makefile', 'a' do |file|
file.puts 'include scripts/get-flags.mk'
file.puts 'include cpu.mk'
if $GGML_METAL
file.puts 'include metal.mk'
if $GGML_METAL_EMBED_LIBRARY
file.puts 'include metal-embed.mk'
end
end
end end

View File

@ -0,0 +1,17 @@
ggml/src/ggml-metal/ggml-metal-embed.o: \
ggml/src/ggml-metal/ggml-metal.metal \
ggml/src/ggml-metal/ggml-metal-impl.h \
ggml/src/ggml-common.h
@echo "Embedding Metal library"
@sed -e '/__embed_ggml-common.h__/r ggml/src/ggml-common.h' -e '/__embed_ggml-common.h__/d' < ggml/src/ggml-metal/ggml-metal.metal > ggml/src/ggml-metal/ggml-metal-embed.metal.tmp
@sed -e '/#include "ggml-metal-impl.h"/r ggml/src/ggml-metal/ggml-metal-impl.h' -e '/#include "ggml-metal-impl.h"/d' < ggml/src/ggml-metal/ggml-metal-embed.metal.tmp > ggml/src/ggml-metal/ggml-metal-embed.metal
$(eval TEMP_ASSEMBLY=$(shell mktemp -d))
@echo ".section __DATA, __ggml_metallib" > $(TEMP_ASSEMBLY)/ggml-metal-embed.s
@echo ".globl _ggml_metallib_start" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
@echo "_ggml_metallib_start:" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
@echo ".incbin \"ggml/src/ggml-metal/ggml-metal-embed.metal\"" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
@echo ".globl _ggml_metallib_end" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
@echo "_ggml_metallib_end:" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
$(CC) $(CFLAGS) -c $(TEMP_ASSEMBLY)/ggml-metal-embed.s -o $@
@rm -f ${TEMP_ASSEMBLY}/ggml-metal-embed.s
@rmdir ${TEMP_ASSEMBLY}

View File

@ -0,0 +1,6 @@
ggml/src/ggml-metal/ggml-metal.o: \
ggml/src/ggml-metal/ggml-metal.m \
ggml/src/ggml-metal/ggml-metal-impl.h \
ggml/include/ggml-metal.h \
ggml/include/ggml.h
$(CC) $(CFLAGS) -c $< -o $@

View File

@ -1,219 +0,0 @@
class Options
def initialize
@options = {}
@pending_options = []
@ignored_options = []
configure
end
def help
@options
.collect_concat {|name, (type, value)|
option = option_name(name)
if type == :bool
["--enable-#{option}", "--disable-#{option}"]
else
"--#{option}=#{type.upcase}"
end
}
.join($/)
end
def to_s
@options
.reject {|name, (type, value)| value.nil?}
.collect {|name, (type, value)| "-D #{name}=#{value == true ? "ON" : value == false ? "OFF" : value.shellescape}"}
.join(" ")
end
def cmake_options
return @cmake_options if @cmake_options
output = nil
Dir.chdir __dir__ do
output = `cmake -S sources -B build -L`
end
started = false
@cmake_options = output.lines.filter_map {|line|
if line.chomp == "-- Cache values"
started = true
next
end
next unless started
option, value = line.chomp.split("=", 2)
name, type = option.split(":", 2)
[name, type, value]
}
end
def missing_options
cmake_options.collect {|name, type, value| name} -
@options.keys - @pending_options - @ignored_options
end
def extra_options
@options.keys + @pending_options - @ignored_options -
cmake_options.collect {|name, type, value| name}
end
private
def configure
filepath "ACCELERATE_FRAMEWORK"
ignored "BUILD_SHARED_LIBS"
ignored "BUILD_TESTING"
ignored "CMAKE_BUILD_TYPE"
ignored "CMAKE_INSTALL_PREFIX"
string "CMAKE_OSX_ARCHITECTURES"
ignored "CMAKE_OSX_DEPLOYMENT_TARGET"
string "CMAKE_OSX_SYSROOT"
filepath "FOUNDATION_LIBRARY"
bool "GGML_ACCELERATE"
bool "GGML_ALL_WARNINGS_3RD_PARTY"
bool "GGML_AMX_BF16"
bool "GGML_AMX_INT8"
bool "GGML_AMX_TILE"
bool "GGML_AVX"
bool "GGML_AVX2"
bool "GGML_AVX512"
bool "GGML_AVX512_BF16"
bool "GGML_AVX512_VBMI"
bool "GGML_AVX512_VNNI"
bool "GGML_AVX_VNNI"
ignored "GGML_BACKEND_DL"
ignored "GGML_BIN_INSTALL_DIR"
bool "GGML_BLAS"
string "GGML_BLAS_VENDOR"
bool "GGML_BMI2"
ignored "GGML_BUILD_EXAMPLES"
ignored "GGML_BUILD_TESTS"
filepath "GGML_CCACHE_FOUND"
bool "GGML_CPU"
bool "GGML_CPU_AARCH64"
ignored "GGML_CPU_ALL_VARIANTS"
string "GGML_CPU_ARM_ARCH"
bool "GGML_CPU_HBM"
bool "GGML_CPU_KLEIDIAI"
string "GGML_CPU_POWERPC_CPUTYPE"
bool "GGML_CUDA"
string "GGML_CUDA_COMPRESSION_MODE"
bool "GGML_CUDA_F16"
bool "GGML_CUDA_FA"
bool "GGML_CUDA_FA_ALL_QUANTS"
bool "GGML_CUDA_FORCE_CUBLAS"
bool "GGML_CUDA_FORCE_MMQ"
ignored "GGML_CUDA_GRAPHS"
bool "GGML_CUDA_NO_PEER_COPY"
bool "GGML_CUDA_NO_VMM"
string "GGML_CUDA_PEER_MAX_BATCH_SIZE"
bool "GGML_F16C"
bool "GGML_FMA"
bool "GGML_GPROF"
bool "GGML_HIP"
bool "GGML_HIP_GRAPHS"
bool "GGML_HIP_NO_VMM"
bool "GGML_HIP_ROCWMMA_FATTN"
ignored "GGML_INCLUDE_INSTALL_DIR"
bool "GGML_KOMPUTE"
bool "GGML_LASX"
ignored "GGML_LIB_INSTALL_DIR"
ignored "GGML_LLAMAFILE"
bool "GGML_LSX"
bool "GGML_LTO"
bool "GGML_METAL"
bool "GGML_METAL_EMBED_LIBRARY"
string "GGML_METAL_MACOSX_VERSION_MIN"
bool "GGML_METAL_NDEBUG"
bool "GGML_METAL_SHADER_DEBUG"
string "GGML_METAL_STD"
bool "GGML_METAL_USE_BF16"
bool "GGML_MUSA"
bool "GGML_NATIVE"
bool "GGML_OPENCL"
bool "GGML_OPENCL_EMBED_KERNELS"
bool "GGML_OPENCL_PROFILING"
string "GGML_OPENCL_TARGET_VERSION"
bool "GGML_OPENCL_USE_ADRENO_KERNELS"
bool "GGML_OPENMP"
bool "GGML_RPC"
bool "GGML_RVV"
bool "GGML_RV_ZFH"
pending "GGML_SCCACHE_FOUND"
string "GGML_SCHED_MAX_COPIES"
bool "GGML_SSE42"
ignored "GGML_STATIC"
bool "GGML_SYCL"
string "GGML_SYCL_DEVICE_ARCH"
bool "GGML_SYCL_F16"
bool "GGML_SYCL_GRAPH"
string "GGML_SYCL_TARGET"
bool "GGML_VULKAN"
bool "GGML_VULKAN_CHECK_RESULTS"
bool "GGML_VULKAN_DEBUG"
bool "GGML_VULKAN_MEMORY_DEBUG"
bool "GGML_VULKAN_PERF"
ignored "GGML_VULKAN_RUN_TESTS"
filepath "GGML_VULKAN_SHADERS_GEN_TOOLCHAIN"
bool "GGML_VULKAN_SHADER_DEBUG_INFO"
pending "GGML_VULKAN_VALIDATE"
bool "GGML_VXE"
filepath "GIT_EXE"
filepath "MATH_LIBRARY"
filepath "METALKIT_FRAMEWORK"
filepath "METAL_FRAMEWORK"
bool "WHISPER_ALL_WARNINGS"
bool "WHISPER_ALL_WARNINGS_3RD_PARTY"
ignored "WHISPER_BIN_INSTALL_DIR"
ignored "WHISPER_BUILD_EXAMPLES"
ignored "WHISPER_BUILD_SERVER"
ignored"WHISPER_BUILD_TESTS"
bool "WHISPER_CCACHE"
bool "WHISPER_COREML"
bool "WHISPER_COREML_ALLOW_FALLBACK"
ignored "WHISPER_CURL"
bool "WHISPER_FATAL_WARNINGS"
ignored "WHISPER_FFMPEG"
ignored "WHISPER_INCLUDE_INSTALL_DIR"
ignored "WHISPER_LIB_INSTALL_DIR"
bool "WHISPER_OPENVINO"
bool "WHISPER_SANITIZE_ADDRESS"
bool "WHISPER_SANITIZE_THREAD"
bool "WHISPER_SANITIZE_UNDEFINED"
ignored "WHISPER_SDL2"
pending "WHISPER_USE_SYSTEM_GGML"
end
def option_name(name)
name.downcase.gsub("_", "-")
end
def bool(name)
option = option_name(name)
value = enable_config(option)
@options[name] = [:bool, value]
end
def string(name, type=:string)
option = "--#{option_name(name)}"
value = arg_config(option)
raise "String expected for #{option}" if value == true || value&.empty?
@options[name] = [type, value]
end
def path(name)
string(name, :path)
end
def filepath(name)
string(name, :filepath)
end
def pending(name)
@pending_options << name
end
def ignored(name)
@ignored_options << name
end
end

View File

@ -19,7 +19,6 @@ typedef struct {
bool diarize; bool diarize;
ruby_whisper_callback_container *new_segment_callback_container; ruby_whisper_callback_container *new_segment_callback_container;
ruby_whisper_callback_container *progress_callback_container; ruby_whisper_callback_container *progress_callback_container;
ruby_whisper_callback_container *encoder_begin_callback_container;
ruby_whisper_callback_container *abort_callback_container; ruby_whisper_callback_container *abort_callback_container;
} ruby_whisper_params; } ruby_whisper_params;

View File

@ -26,7 +26,7 @@
rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \ rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \
rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1); rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1);
#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 32 #define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 30
extern VALUE cParams; extern VALUE cParams;
@ -63,8 +63,6 @@ static ID id_new_segment_callback;
static ID id_new_segment_callback_user_data; static ID id_new_segment_callback_user_data;
static ID id_progress_callback; static ID id_progress_callback;
static ID id_progress_callback_user_data; static ID id_progress_callback_user_data;
static ID id_encoder_begin_callback;
static ID id_encoder_begin_callback_user_data;
static ID id_abort_callback; static ID id_abort_callback;
static ID id_abort_callback_user_data; static ID id_abort_callback_user_data;
@ -128,33 +126,6 @@ static void progress_callback(struct whisper_context *ctx, struct whisper_state
} }
} }
static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_state *state, void *user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
bool is_aborted = false;
VALUE result;
// Currently, doesn't support state because
// those require to resolve GC-related problems.
if (!NIL_P(container->callback)) {
result = rb_funcall(container->callback, id_call, 3, *container->context, Qnil, container->user_data);
if (result == Qfalse) {
is_aborted = true;
}
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
return !is_aborted;
}
for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
result = rb_funcall(cb, id_call, 0);
if (result == Qfalse) {
is_aborted = true;
}
}
return !is_aborted;
}
static bool abort_callback(void * user_data) { static bool abort_callback(void * user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
if (!NIL_P(container->callback)) { if (!NIL_P(container->callback)) {
@ -190,12 +161,6 @@ void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
rwp->params.progress_callback_user_data = rwp->progress_callback_container; rwp->params.progress_callback_user_data = rwp->progress_callback_container;
} }
if (!NIL_P(rwp->encoder_begin_callback_container->callback) || 0 != RARRAY_LEN(rwp->encoder_begin_callback_container->callbacks)) {
rwp->encoder_begin_callback_container->context = context;
rwp->params.encoder_begin_callback = encoder_begin_callback;
rwp->params.encoder_begin_callback_user_data = rwp->encoder_begin_callback_container;
}
if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) { if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
rwp->abort_callback_container->context = context; rwp->abort_callback_container->context = context;
rwp->params.abort_callback = abort_callback; rwp->params.abort_callback = abort_callback;
@ -208,7 +173,6 @@ rb_whisper_params_mark(ruby_whisper_params *rwp)
{ {
rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container); rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
rb_whisper_callbcack_container_mark(rwp->progress_callback_container); rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
rb_whisper_callbcack_container_mark(rwp->encoder_begin_callback_container);
rb_whisper_callbcack_container_mark(rwp->abort_callback_container); rb_whisper_callbcack_container_mark(rwp->abort_callback_container);
} }
@ -234,7 +198,6 @@ ruby_whisper_params_allocate(VALUE klass)
rwp->diarize = false; rwp->diarize = false;
rwp->new_segment_callback_container = rb_whisper_callback_container_allocate(); rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
rwp->progress_callback_container = rb_whisper_callback_container_allocate(); rwp->progress_callback_container = rb_whisper_callback_container_allocate();
rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate();
rwp->abort_callback_container = rb_whisper_callback_container_allocate(); rwp->abort_callback_container = rb_whisper_callback_container_allocate();
return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp); return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
} }
@ -886,57 +849,6 @@ ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value)
rwp->progress_callback_container->user_data = value; rwp->progress_callback_container->user_data = value;
return value; return value;
} }
static VALUE
ruby_whisper_params_get_encoder_begin_callback(VALUE self)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->encoder_begin_callback_container->callback;
}
/*
* Sets encoder begin callback, called when the encoder starts.
*
* params.encoder_begin_callback = ->(context, _, user_data) {
* # ...
* }
*
* call-seq:
* encoder_begin_callback = callback -> callback
*/
static VALUE
ruby_whisper_params_set_encoder_begin_callback(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->encoder_begin_callback_container->callback = value;
return value;
}
static VALUE
ruby_whisper_params_get_encoder_begin_callback_user_data(VALUE self)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->encoder_begin_callback_container->user_data;
}
/*
* Sets user data passed to the last argument of encoder begin callback.
*
* call-seq:
* encoder_begin_callback_user_data = user_data -> use_data
*/
static VALUE
ruby_whisper_params_set_encoder_begin_callback_user_data(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->encoder_begin_callback_container->user_data = value;
return value;
}
static VALUE static VALUE
ruby_whisper_params_get_abort_callback(VALUE self) ruby_whisper_params_get_abort_callback(VALUE self)
{ {
@ -1006,7 +918,7 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
return self; return self;
} }
rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT, values); rb_get_kwargs(kw_hash, &param_names, 0, RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT, &values);
Data_Get_Struct(self, ruby_whisper_params, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
for (i = 0; i < RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT; i++) { for (i = 0; i < RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT; i++) {
@ -1046,8 +958,6 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
SET_PARAM_IF_SAME(new_segment_callback_user_data) SET_PARAM_IF_SAME(new_segment_callback_user_data)
SET_PARAM_IF_SAME(progress_callback) SET_PARAM_IF_SAME(progress_callback)
SET_PARAM_IF_SAME(progress_callback_user_data) SET_PARAM_IF_SAME(progress_callback_user_data)
SET_PARAM_IF_SAME(encoder_begin_callback)
SET_PARAM_IF_SAME(encoder_begin_callback_user_data)
SET_PARAM_IF_SAME(abort_callback) SET_PARAM_IF_SAME(abort_callback)
SET_PARAM_IF_SAME(abort_callback_user_data) SET_PARAM_IF_SAME(abort_callback_user_data)
} }
@ -1098,26 +1008,6 @@ ruby_whisper_params_on_progress(VALUE self)
return Qnil; return Qnil;
} }
/*
* Hook called when the encoder starts.
*
* whisper.on_encoder_begin do
* # ...
* end
*
* call-seq:
* on_encoder_begin { ... }
*/
static VALUE
ruby_whisper_params_on_encoder_begin(VALUE self)
{
ruby_whisper_params *rws;
Data_Get_Struct(self, ruby_whisper_params, rws);
const VALUE blk = rb_block_proc();
rb_ary_push(rws->encoder_begin_callback_container->callbacks, blk);
return Qnil;
}
/* /*
* Call block to determine whether abort or not. Return +true+ when you want to abort. * Call block to determine whether abort or not. Return +true+ when you want to abort.
* *
@ -1178,13 +1068,10 @@ init_ruby_whisper_params(VALUE *mWhisper)
DEFINE_PARAM(new_segment_callback_user_data, 25) DEFINE_PARAM(new_segment_callback_user_data, 25)
DEFINE_PARAM(progress_callback, 26) DEFINE_PARAM(progress_callback, 26)
DEFINE_PARAM(progress_callback_user_data, 27) DEFINE_PARAM(progress_callback_user_data, 27)
DEFINE_PARAM(encoder_begin_callback, 28) DEFINE_PARAM(abort_callback, 28)
DEFINE_PARAM(encoder_begin_callback_user_data, 29) DEFINE_PARAM(abort_callback_user_data, 29)
DEFINE_PARAM(abort_callback, 30)
DEFINE_PARAM(abort_callback_user_data, 31)
rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0); rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0); rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
rb_define_method(cParams, "on_encoder_begin", ruby_whisper_params_on_encoder_begin, 0);
rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0); rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0);
} }

View File

@ -50,16 +50,15 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str()); fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return self; return self;
} }
// Commented out because it is work in progress {
// { static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
// static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
// rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
// bool is_aborted = *(bool*)user_data; bool is_aborted = *(bool*)user_data;
// return !is_aborted; return !is_aborted;
// }; };
// rwp->params.encoder_begin_callback_user_data = &is_aborted; rwp->params.encoder_begin_callback_user_data = &is_aborted;
// } }
register_callbacks(rwp, &self); register_callbacks(rwp, &self);

View File

@ -1,8 +0,0 @@
set(GRAPHVIZ_EXECUTABLES FALSE)
set(GRAPHVIZ_STATIC_LIBS TRUE)
set(GRAPHVIZ_SHARED_LIBS FALSE)
set(GRAPHVIZ_MODULE_LIBS FALSE)
set(GRAPHVIZ_INTERFACE_LIBS FALSE)
set(GRAPHVIZ_OBJECT_LIBS FALSE)
set(GRAPHVIZ_UNKNOWN_LIBS FALSE)
set(GRAPHVIZ_GENERATE_DEPENDERS FALSE)

View File

@ -1,34 +1,6 @@
ignored_dirs = %w[ require "yaml"
.devops
examples/wchess/wchess.wasm
examples/whisper.android
examples/whisper.android.java
examples/whisper.objc
examples/whisper.swiftui
grammars
models
samples
scripts
]
ignored_files = %w[
AUTHORS
Makefile
README.md
README_sycl.md
.gitignore
.gitmodules
whisper.nvim
twitch.sh
yt-wsp.sh
]
EXTSOURCES = sources = `git ls-files -z ../..`.split("\x0")
`git ls-files -z ../..`.split("\x0") paths = YAML.load_file("../../.github/workflows/bindings-ruby.yml")[true]["push"]["paths"]
.select {|file| paths.delete "bindings/ruby/**"
basename = File.basename(file) EXTSOURCES = (Dir.glob(paths, base: "../..").collect {|path| "../../#{path}"} << "../../LICENSE") & sources
ignored_dirs.all? {|dir| !file.start_with?("../../#{dir}")} &&
!ignored_files.include?(basename) &&
(file.start_with?("../..") || file.start_with?("../javascript")) &&
(!file.start_with?("../../.github/") || basename == "bindings-ruby.yml")
}

View File

@ -34,7 +34,7 @@ module Whisper
when /darwin/ when /darwin/
Pathname(Dir.home)/"Library/Caches" Pathname(Dir.home)/"Library/Caches"
else else
ENV.key?("XDG_CACHE_HOME") ? Pathname(ENV["XDG_CACHE_HOME"]) : Pathname(Dir.home)/".cache" ENV.key?("XDG_CACHE_HOME") ? ENV["XDG_CACHE_HOME"] : Pathname(Dir.home)/".cache"
end end
base/"whisper.cpp" base/"whisper.cpp"
end end
@ -55,8 +55,6 @@ module Whisper
when Net::HTTPNotModified when Net::HTTPNotModified
# noop # noop
when Net::HTTPOK when Net::HTTPOK
return if !response.key?("last-modified") && cache_path.exist?
download response download response
when Net::HTTPRedirection when Net::HTTPRedirection
request URI(response["location"]), headers request URI(response["location"]), headers

View File

@ -7,7 +7,6 @@ module Whisper
type log_callback = ^(Integer level, String message, Object user_data) -> void type log_callback = ^(Integer level, String message, Object user_data) -> void
type new_segment_callback = ^(Whisper::Context, void, Integer n_new, Object user_data) -> void type new_segment_callback = ^(Whisper::Context, void, Integer n_new, Object user_data) -> void
type progress_callback = ^(Whisper::Context, void, Integer progress, Object user_data) -> void type progress_callback = ^(Whisper::Context, void, Integer progress, Object user_data) -> void
type encoder_begin_callback = ^(Whisper::Context, void, Object user_data) -> void
type abort_callback = ^(Whisper::Context, void, Object user_data) -> boolish type abort_callback = ^(Whisper::Context, void, Object user_data) -> boolish
LOG_LEVEL_NONE: Integer LOG_LEVEL_NONE: Integer
@ -24,20 +23,9 @@ module Whisper
def self.log_set: (log_callback, Object? user_data) -> log_callback def self.log_set: (log_callback, Object? user_data) -> log_callback
class Context class Context
def self.new: (path | ::URI::HTTP) -> instance def self.new: (string | _ToPath | ::URI::HTTP) -> instance
# transcribe a single file
# can emit to a block results
#
# params = Whisper::Params.new
# params.duration = 60_000
# whisper.transcribe "path/to/audio.wav", params do |text|
# puts text
# end
#
def transcribe: (string, Params) -> self def transcribe: (string, Params) -> self
| (string, Params) { (String) -> void } -> self | (string, Params) { (String) -> void } -> self
def model_n_vocab: () -> Integer def model_n_vocab: () -> Integer
def model_n_audio_ctx: () -> Integer def model_n_audio_ctx: () -> Integer
def model_n_audio_state: () -> Integer def model_n_audio_state: () -> Integer
@ -46,72 +34,19 @@ module Whisper
def model_n_mels: () -> Integer def model_n_mels: () -> Integer
def model_ftype: () -> Integer def model_ftype: () -> Integer
def model_type: () -> String def model_type: () -> String
# Yields each Whisper::Segment:
#
# whisper.transcribe("path/to/audio.wav", params)
# whisper.each_segment do |segment|
# puts segment.text
# end
#
# Returns an Enumerator if no block given:
#
# whisper.transcribe("path/to/audio.wav", params)
# enum = whisper.each_segment
# enum.to_a # => [#<Whisper::Segment>, ...]
#
def each_segment: { (Segment) -> void } -> void def each_segment: { (Segment) -> void } -> void
| () -> Enumerator[Segment] | () -> Enumerator[Segment]
def model: () -> Model def model: () -> Model
def full_get_segment: (Integer nth) -> Segment def full_get_segment: (Integer nth) -> Segment
def full_n_segments: () -> Integer def full_n_segments: () -> Integer
# Language ID, which can be converted to string by Whisper.lang_str and Whisper.lang_str_full.
#
def full_lang_id: () -> Integer def full_lang_id: () -> Integer
# Start time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
#
# full_get_segment_t0(3) # => 1668 (16680 ms)
#
def full_get_segment_t0: (Integer) -> Integer def full_get_segment_t0: (Integer) -> Integer
# End time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
#
# full_get_segment_t1(3) # => 1668 (16680 ms)
#
def full_get_segment_t1: (Integer) -> Integer def full_get_segment_t1: (Integer) -> Integer
# Whether the next segment indexed by +segment_index+ is predicated as a speaker turn.
#
# full_get_segment_speacker_turn_next(3) # => true
#
def full_get_segment_speaker_turn_next: (Integer) -> (true | false) def full_get_segment_speaker_turn_next: (Integer) -> (true | false)
# Text of a segment indexed by +segment_index+.
#
# full_get_segment_text(3) # => "ask not what your country can do for you, ..."
#
def full_get_segment_text: (Integer) -> String def full_get_segment_text: (Integer) -> String
def full_get_segment_no_speech_prob: (Integer) -> Float def full_get_segment_no_speech_prob: (Integer) -> Float
# Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
# Not thread safe for same context
# Uses the specified decoding strategy to obtain the text.
#
# The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data.
#
def full: (Params, Array[Float] samples, ?Integer n_samples) -> self def full: (Params, Array[Float] samples, ?Integer n_samples) -> self
| (Params, _Samples, ?Integer n_samples) -> self | (Params, _Samples, ?Integer n_samples) -> self
# Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
# Result is stored in the default state of the context
# Not thread safe if executed in parallel on the same context.
# It seems this approach can offer some speedup in some cases.
# However, the transcription accuracy can be worse at the beginning and end of each chunk.
#
def full_parallel: (Params, Array[Float], ?Integer n_samples) -> self def full_parallel: (Params, Array[Float], ?Integer n_samples) -> self
| (Params, _Samples, ?Integer n_samples) -> self | (Params, _Samples, ?Integer n_samples) -> self
| (Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self | (Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self
@ -147,223 +82,71 @@ module Whisper
?new_segment_callback_user_data: Object, ?new_segment_callback_user_data: Object,
?progress_callback: progress_callback, ?progress_callback: progress_callback,
?progress_callback_user_data: Object, ?progress_callback_user_data: Object,
?encoder_begin_callback: encoder_begin_callback,
?encoder_begin_callback_user_data: Object,
?abort_callback: abort_callback, ?abort_callback: abort_callback,
?abort_callback_user_data: Object ?abort_callback_user_data: Object
) -> instance ) -> instance
# params.language = "auto" | "en", etc...
#
def language=: (String) -> String # TODO: Enumerate lang names def language=: (String) -> String # TODO: Enumerate lang names
def language: () -> String def language: () -> String
def translate=: (boolish) -> boolish def translate=: (boolish) -> boolish
def translate: () -> (true | false) def translate: () -> (true | false)
def no_context=: (boolish) -> boolish def no_context=: (boolish) -> boolish
# If true, does not use past transcription (if any) as initial prompt for the decoder.
#
def no_context: () -> (true | false) def no_context: () -> (true | false)
def single_segment=: (boolish) -> boolish def single_segment=: (boolish) -> boolish
# If true, forces single segment output (useful for streaming).
#
def single_segment: () -> (true | false) def single_segment: () -> (true | false)
def print_special=: (boolish) -> boolish def print_special=: (boolish) -> boolish
# If true, prints special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.).
#
def print_special: () -> (true | false) def print_special: () -> (true | false)
def print_progress=: (boolish) -> boolish def print_progress=: (boolish) -> boolish
# If true, prints progress information.
#
def print_progress: () -> (true | false) def print_progress: () -> (true | false)
def print_realtime=: (boolish) -> boolish def print_realtime=: (boolish) -> boolish
# If true, prints results from within whisper.cpp. (avoid it, use callback instead)
#
def print_realtime: () -> (true | false) def print_realtime: () -> (true | false)
# If true, prints timestamps for each text segment when printing realtime.
#
def print_timestamps=: (boolish) -> boolish def print_timestamps=: (boolish) -> boolish
def print_timestamps: () -> (true | false) def print_timestamps: () -> (true | false)
def suppress_blank=: (boolish) -> boolish def suppress_blank=: (boolish) -> boolish
# If true, suppresses blank outputs.
#
def suppress_blank: () -> (true | false) def suppress_blank: () -> (true | false)
def suppress_nst=: (boolish) -> boolish def suppress_nst=: (boolish) -> boolish
# If true, suppresses non-speech-tokens.
#
def suppress_nst: () -> (true | false) def suppress_nst: () -> (true | false)
def token_timestamps=: (boolish) -> boolish def token_timestamps=: (boolish) -> boolish
# If true, enables token-level timestamps.
#
def token_timestamps: () -> (true | false) def token_timestamps: () -> (true | false)
def split_on_word=: (boolish) -> boolish def split_on_word=: (boolish) -> boolish
# If true, split on word rather than on token (when used with max_len).
#
def split_on_word: () -> (true | false) def split_on_word: () -> (true | false)
def initial_prompt=: (_ToS) -> _ToS def initial_prompt=: (_ToS) -> _ToS
# Tokens to provide to the whisper decoder as initial prompt
# these are prepended to any existing text context from a previous call
# use whisper_tokenize() to convert text to tokens.
# Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224).
#
def initial_prompt: () -> (String | nil) def initial_prompt: () -> (String | nil)
def diarize=: (boolish) -> boolish def diarize=: (boolish) -> boolish
# If true, enables diarization.
#
def diarize: () -> (true | false) def diarize: () -> (true | false)
def offset=: (Integer) -> Integer def offset=: (Integer) -> Integer
# Start offset in ms.
#
def offset: () -> Integer def offset: () -> Integer
def duration=: (Integer) -> Integer def duration=: (Integer) -> Integer
# Audio duration to process in ms.
#
def duration: () -> Integer def duration: () -> Integer
def max_text_tokens=: (Integer) -> Integer def max_text_tokens=: (Integer) -> Integer
# Max tokens to use from past text as prompt for the decoder.
#
def max_text_tokens: () -> Integer def max_text_tokens: () -> Integer
def temperature=: (Float) -> Float def temperature=: (Float) -> Float
def temperature: () -> Float def temperature: () -> Float
def max_initial_ts=: (Float) -> Float def max_initial_ts=: (Float) -> Float
# See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
#
def max_initial_ts: () -> Float def max_initial_ts: () -> Float
def length_penalty=: (Float) -> Float def length_penalty=: (Float) -> Float
def length_penalty: () -> Float def length_penalty: () -> Float
def temperature_inc=: (Float) -> Float def temperature_inc=: (Float) -> Float
def temperature_inc: () -> Float def temperature_inc: () -> Float
def entropy_thold=: (Float) -> Float def entropy_thold=: (Float) -> Float
# Similar to OpenAI's "compression_ratio_threshold"
#
def entropy_thold: () -> Float def entropy_thold: () -> Float
def logprob_thold=: (Float) -> Float def logprob_thold=: (Float) -> Float
def logprob_thold: () -> Float def logprob_thold: () -> Float
def no_speech_thold=: (Float) -> Float def no_speech_thold=: (Float) -> Float
def no_speech_thold: () -> Float def no_speech_thold: () -> Float
# Sets new segment callback, called for every newly generated text segment.
#
# params.new_segment_callback = ->(context, _, n_new, user_data) {
# # ...
# }
#
def new_segment_callback=: (new_segment_callback) -> new_segment_callback def new_segment_callback=: (new_segment_callback) -> new_segment_callback
def new_segment_callback: () -> (new_segment_callback | nil) def new_segment_callback: () -> (new_segment_callback | nil)
# Sets user data passed to the last argument of new segment callback.
#
def new_segment_callback_user_data=: (Object) -> Object def new_segment_callback_user_data=: (Object) -> Object
def new_segment_callback_user_data: () -> Object def new_segment_callback_user_data: () -> Object
# Sets progress callback, called on each progress update.
#
# params.new_segment_callback = ->(context, _, progress, user_data) {
# # ...
# }
#
# +progress+ is an Integer between 0 and 100.
#
def progress_callback=: (progress_callback) -> progress_callback def progress_callback=: (progress_callback) -> progress_callback
def progress_callback: () -> (progress_callback | nil) def progress_callback: () -> (progress_callback | nil)
# Sets user data passed to the last argument of progress callback.
#
def progress_callback_user_data=: (Object) -> Object def progress_callback_user_data=: (Object) -> Object
def progress_callback_user_data: () -> Object def progress_callback_user_data: () -> Object
# Sets encoder begin callback, called when the encoder starts.
#
def encoder_begin_callback=: (encoder_begin_callback) -> encoder_begin_callback
def encoder_begin_callback: () -> (encoder_begin_callback | nil)
# Sets user data passed to the last argument of encoder begin callback.
#
def encoder_begin_callback_user_data=: (Object) -> Object
def encoder_begin_callback_user_data: () -> Object
# Sets abort callback, called to check if the process should be aborted.
#
# params.abort_callback = ->(user_data) {
# # ...
# }
#
#
def abort_callback=: (abort_callback) -> abort_callback def abort_callback=: (abort_callback) -> abort_callback
def abort_callback: () -> (abort_callback | nil) def abort_callback: () -> (abort_callback | nil)
# Sets user data passed to the last argument of abort callback.
#
def abort_callback_user_data=: (Object) -> Object def abort_callback_user_data=: (Object) -> Object
def abort_callback_user_data: () -> Object def abort_callback_user_data: () -> Object
# Hook called on new segment. Yields each Whisper::Segment.
#
# whisper.on_new_segment do |segment|
# # ...
# end
#
def on_new_segment: { (Segment) -> void } -> void def on_new_segment: { (Segment) -> void } -> void
# Hook called on progress update. Yields each progress Integer between 0 and 100.
#
def on_progress: { (Integer progress) -> void } -> void def on_progress: { (Integer progress) -> void } -> void
# Hook called on encoder starts.
#
def on_encoder_begin: { () -> void } -> void
# Call block to determine whether abort or not. Return +true+ when you want to abort.
#
# params.abort_on do
# if some_condition
# true # abort
# else
# false # continue
# end
# end
#
def abort_on: { (Object user_data) -> boolish } -> void def abort_on: { (Object user_data) -> boolish } -> void
end end
@ -384,24 +167,16 @@ module Whisper
def type: () -> String def type: () -> String
class URI class URI
def self.new: (string | ::URI::HTTP) -> instance def self.new: (string | ::URI::HTTP) -> self
def to_path: -> String def to_path: -> String
def clear_cache: -> void def clear_cache: -> void
end end
end end
class Segment class Segment
# Start time in milliseconds.
#
def start_time: () -> Integer def start_time: () -> Integer
# End time in milliseconds.
#
def end_time: () -> Integer def end_time: () -> Integer
# Whether the next segment is predicted as a speaker turn.
def speaker_next_turn?: () -> (true | false) def speaker_next_turn?: () -> (true | false)
def text: () -> String def text: () -> String
def no_speech_prob: () -> Float def no_speech_prob: () -> Float
end end

View File

@ -6,9 +6,9 @@ class TestBase < Test::Unit::TestCase
AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav") AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")
class << self class << self
def whisper attr_reader :whisper
return @whisper if @whisper
def startup
@whisper = Whisper::Context.new("base.en") @whisper = Whisper::Context.new("base.en")
params = Whisper::Params.new params = Whisper::Params.new
params.print_timestamps = false params.print_timestamps = false
@ -21,15 +21,4 @@ class TestBase < Test::Unit::TestCase
def whisper def whisper
self.class.whisper self.class.whisper
end end
module BuildOptions
load "ext/options.rb", self
Options.include self
def enable_config(name)
end
def arg_config(name)
end
end
end end

View File

@ -25,7 +25,7 @@ class TestCallback < TestBase
assert start_time >= 0 assert start_time >= 0
assert_kind_of Integer, end_time assert_kind_of Integer, end_time
assert end_time > 0 assert end_time > 0
assert_match(/ask not what your country can do for you, ask what you can do for your country/, text) if i_segment == 0 assert_match /ask not what your country can do for you, ask what you can do for your country/, text if i_segment == 0
end end
} }
@ -111,48 +111,6 @@ class TestCallback < TestBase
assert_equal 100, last assert_equal 100, last
end end
def test_encoder_begin_callback
i = 0
@params.encoder_begin_callback = ->(context, state, user_data) {
i += 1
}
@whisper.transcribe(@audio, @params)
assert i > 0
end
def test_encoder_begin_callback_abort
logs = []
Whisper.log_set -> (level, buffer, user_data) {
logs << buffer if level == Whisper::LOG_LEVEL_ERROR
}, logs
@params.encoder_begin_callback = ->(context, state, user_data) {
return false
}
@whisper.transcribe(@audio, @params)
assert_match(/encoder_begin_callback returned false - aborting/, logs.join)
Whisper.log_set ->(level, buffer, user_data) {}, nil
end
def test_encoder_begin_callback_user_data
udata = Object.new
@params.encoder_begin_callback_user_data = udata
yielded = nil
@params.encoder_begin_callback = ->(context, state, user_data) {
yielded = user_data
}
@whisper.transcribe(@audio, @params)
assert_same udata, yielded
end
def test_on_encoder_begin
i = 0
@params.on_encoder_begin do
i += 1
end
@whisper.transcribe(@audio, @params)
assert i > 0
end
def test_abort_callback def test_abort_callback
i = 0 i = 0
@params.abort_callback = ->(user_data) { @params.abort_callback = ->(user_data) {
@ -187,9 +145,9 @@ class TestCallback < TestBase
def test_abort_on def test_abort_on
do_abort = false do_abort = false
_aborted_from_callback = false aborted_from_callback = false
@params.on_new_segment do |segment| @params.on_new_segment do |segment|
do_abort = true if segment.text.match?(/ask/) do_abort = true if segment.text.match? /ask/
end end
i = 0 i = 0
@params.abort_on do @params.abort_on do

View File

@ -4,7 +4,7 @@ class TestError < TestBase
def test_error def test_error
error = Whisper::Error.new(-2) error = Whisper::Error.new(-2)
assert_equal "failed to compute log mel spectrogram", error.message assert_equal "failed to compute log mel spectrogram", error.message
assert_equal(-2, error.code) assert_equal -2, error.code
end end
def test_unknown_error def test_unknown_error
@ -14,7 +14,7 @@ class TestError < TestBase
def test_non_int_code def test_non_int_code
assert_raise TypeError do assert_raise TypeError do
_error = Whisper::Error.new("non int") error = Whisper::Error.new("non int")
end end
end end
end end

View File

@ -21,26 +21,11 @@ class TestPackage < TestBase
match_data = `rake -Tbuild`.match(/(whispercpp-(.+)\.gem)/) match_data = `rake -Tbuild`.match(/(whispercpp-(.+)\.gem)/)
filename = match_data[1] filename = match_data[1]
version = match_data[2] version = match_data[2]
basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}"
Dir.mktmpdir do |dir| Dir.mktmpdir do |dir|
system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{filename.shellescape}", exception: true system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{filename.shellescape}", exception: true
assert_installed dir, version assert_path_exist File.join(dir, "gems/whispercpp-#{version}/lib", basename)
end end
end end
private
def assert_installed(dir, version)
assert_path_exist File.join(dir, "gems/whispercpp-#{version}/lib", "whisper.#{RbConfig::CONFIG["DLEXT"]}")
assert_path_exist File.join(dir, "gems/whispercpp-#{version}/LICENSE")
assert_path_not_exist File.join(dir, "gems/whispercpp-#{version}/ext/build")
end
end
def test_build_options
options = BuildOptions::Options.new
assert_empty options.missing_options
unless ENV["CI"]
assert_empty options.extra_options
end
end end
end end

View File

@ -162,7 +162,7 @@ class TestParams < TestBase
end end
def test_length_penalty def test_length_penalty
assert_equal(-1.0, @params.length_penalty) assert_equal -1.0, @params.length_penalty
@params.length_penalty = 0.5 @params.length_penalty = 0.5
assert_equal 0.5, @params.length_penalty assert_equal 0.5, @params.length_penalty
end end
@ -180,9 +180,9 @@ class TestParams < TestBase
end end
def test_logprob_thold def test_logprob_thold
assert_in_delta(-1.0, @params.logprob_thold) assert_in_delta -1.0, @params.logprob_thold
@params.logprob_thold = -0.5 @params.logprob_thold = -0.5
assert_in_delta(-0.5, @params.logprob_thold) assert_in_delta -0.5, @params.logprob_thold
end end
def test_no_speech_thold def test_no_speech_thold

View File

@ -49,13 +49,13 @@ class TestSegment < TestBase
if index == 0 if index == 0
seg = segment seg = segment
assert_equal 0, segment.start_time assert_equal 0, segment.start_time
assert_match(/ask not what your country can do for you, ask what you can do for your country/, segment.text) assert_match /ask not what your country can do for you, ask what you can do for your country/, segment.text
end end
index += 1 index += 1
end end
whisper.transcribe(AUDIO, params) whisper.transcribe(AUDIO, params)
assert_equal 0, seg.start_time assert_equal 0, seg.start_time
assert_match(/ask not what your country can do for you, ask what you can do for your country/, seg.text) assert_match /ask not what your country can do for you, ask what you can do for your country/, seg.text
end end
def test_on_new_segment_twice def test_on_new_segment_twice

View File

@ -16,7 +16,7 @@ class TestWhisper < TestBase
params.print_timestamps = false params.print_timestamps = false
@whisper.transcribe(AUDIO, params) {|text| @whisper.transcribe(AUDIO, params) {|text|
assert_match(/ask not what your country can do for you, ask what you can do for your country/, text) assert_match /ask not what your country can do for you, ask what you can do for your country/, text
} }
end end
@ -32,7 +32,7 @@ class TestWhisper < TestBase
def test_full_get_segment def test_full_get_segment
segment = whisper.full_get_segment(0) segment = whisper.full_get_segment(0)
assert_equal 0, segment.start_time assert_equal 0, segment.start_time
assert_match(/ask not what your country can do for you, ask what you can do for your country/, segment.text) assert_match /ask not what your country can do for you, ask what you can do for your country/, segment.text
end end
def test_full_get_segment_t0 def test_full_get_segment_t0
@ -59,7 +59,7 @@ class TestWhisper < TestBase
end end
def test_full_get_segment_text def test_full_get_segment_text
assert_match(/ask not what your country can do for you, ask what you can do for your country/, whisper.full_get_segment_text(0)) assert_match /ask not what your country can do for you, ask what you can do for your country/, whisper.full_get_segment_text(0)
end end
def test_full_get_segment_no_speech_prob def test_full_get_segment_no_speech_prob
@ -134,14 +134,14 @@ class TestWhisper < TestBase
@whisper.full(@params, @samples, @samples.length) @whisper.full(@params, @samples, @samples.length)
assert_equal 1, @whisper.full_n_segments assert_equal 1, @whisper.full_n_segments
assert_match(/ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text) assert_match /ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text
end end
def test_full_without_length def test_full_without_length
@whisper.full(@params, @samples) @whisper.full(@params, @samples)
assert_equal 1, @whisper.full_n_segments assert_equal 1, @whisper.full_n_segments
assert_match(/ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text) assert_match /ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text
end end
def test_full_enumerator def test_full_enumerator
@ -149,7 +149,7 @@ class TestWhisper < TestBase
@whisper.full(@params, samples, @samples.length) @whisper.full(@params, samples, @samples.length)
assert_equal 1, @whisper.full_n_segments assert_equal 1, @whisper.full_n_segments
assert_match(/ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text) assert_match /ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text
end end
def test_full_enumerator_without_length def test_full_enumerator_without_length
@ -171,28 +171,26 @@ class TestWhisper < TestBase
@whisper.full(@params, samples) @whisper.full(@params, samples)
assert_equal 1, @whisper.full_n_segments assert_equal 1, @whisper.full_n_segments
assert_match(/ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text) assert_match /ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text
end end
def test_full_parallel def test_full_parallel
nprocessors = 2 @whisper.full_parallel(@params, @samples, @samples.length, Etc.nprocessors)
@whisper.full_parallel(@params, @samples, @samples.length, nprocessors)
assert_equal nprocessors, @whisper.full_n_segments assert_equal Etc.nprocessors, @whisper.full_n_segments
text = @whisper.each_segment.collect(&:text).join text = @whisper.each_segment.collect(&:text).join
assert_match(/ask what you can do/i, text) assert_match /ask what you can do/i, text
assert_match(/for your country/i, text) assert_match /for your country/i, text
end end
def test_full_parallel_with_memory_view def test_full_parallel_with_memory_view
nprocessors = 2
samples = JFKReader.new(AUDIO) samples = JFKReader.new(AUDIO)
@whisper.full_parallel(@params, samples, nil, nprocessors) @whisper.full_parallel(@params, samples, nil, Etc.nprocessors)
assert_equal nprocessors, @whisper.full_n_segments assert_equal Etc.nprocessors, @whisper.full_n_segments
text = @whisper.each_segment.collect(&:text).join text = @whisper.each_segment.collect(&:text).join
assert_match(/ask what you can do/i, text) assert_match /ask what you can do/i, text
assert_match(/for your country/i, text) assert_match /for your country/i, text
end end
def test_full_parallel_without_length_and_n_processors def test_full_parallel_without_length_and_n_processors
@ -200,18 +198,17 @@ class TestWhisper < TestBase
assert_equal 1, @whisper.full_n_segments assert_equal 1, @whisper.full_n_segments
text = @whisper.each_segment.collect(&:text).join text = @whisper.each_segment.collect(&:text).join
assert_match(/ask what you can do/i, text) assert_match /ask what you can do/i, text
assert_match(/for your country/i, text) assert_match /for your country/i, text
end end
def test_full_parallel_without_length def test_full_parallel_without_length
nprocessors = 2 @whisper.full_parallel(@params, @samples, nil, Etc.nprocessors)
@whisper.full_parallel(@params, @samples, nil, nprocessors)
assert_equal nprocessors, @whisper.full_n_segments assert_equal Etc.nprocessors, @whisper.full_n_segments
text = @whisper.each_segment.collect(&:text).join text = @whisper.each_segment.collect(&:text).join
assert_match(/ask what you can do/i, text) assert_match /ask what you can do/i, text
assert_match(/for your country/i, text) assert_match /for your country/i, text
end end
def test_full_parallel_without_n_processors def test_full_parallel_without_n_processors
@ -219,8 +216,8 @@ class TestWhisper < TestBase
assert_equal 1, @whisper.full_n_segments assert_equal 1, @whisper.full_n_segments
text = @whisper.each_segment.collect(&:text).join text = @whisper.each_segment.collect(&:text).join
assert_match(/ask what you can do/i, text) assert_match /ask what you can do/i, text
assert_match(/for your country/i, text) assert_match /for your country/i, text
end end
end end
end end

View File

@ -3,8 +3,8 @@ require_relative "extsources"
Gem::Specification.new do |s| Gem::Specification.new do |s|
s.name = "whispercpp" s.name = "whispercpp"
s.authors = ["Georgi Gerganov", "Todd A. Fisher"] s.authors = ["Georgi Gerganov", "Todd A. Fisher"]
s.version = '1.3.2' s.version = '1.3.1'
s.date = '2025-05-01' s.date = '2024-12-19'
s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby} s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby}
s.email = 'todd.fisher@gmail.com' s.email = 'todd.fisher@gmail.com'
s.extra_rdoc_files = ['LICENSE', 'README.md'] s.extra_rdoc_files = ['LICENSE', 'README.md']
@ -15,8 +15,7 @@ Gem::Specification.new do |s|
if s.extra_rdoc_files.include?(basename) if s.extra_rdoc_files.include?(basename)
basename basename
else else
file.sub("../..", "ext/sources") file.sub("../..", "ext")
.sub("../javascript", "ext/sources/bindings/javascript")
end end
} }
@ -27,7 +26,7 @@ Gem::Specification.new do |s|
s.required_ruby_version = '>= 3.1.0' s.required_ruby_version = '>= 3.1.0'
#### Documentation and testing. #### Documentation and testing.
s.homepage = 'https://github.com/ggml-org/whisper.cpp' s.homepage = 'https://github.com/ggerganov/whisper.cpp'
s.rdoc_options = ['--main', 'README.md'] s.rdoc_options = ['--main', 'README.md']

View File

@ -41,11 +41,6 @@ COMMON_CMAKE_ARGS=(
-DGGML_OPENMP=${GGML_OPENMP} -DGGML_OPENMP=${GGML_OPENMP}
) )
XCODE_VERSION=$(xcodebuild -version 2>/dev/null | head -n1 | awk '{ print $2 }')
MAJOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f1)
MINOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f2)
echo "Detected Xcode version: $XCODE_VERSION"
check_required_tool() { check_required_tool() {
local tool=$1 local tool=$1
local install_message=$2 local install_message=$2
@ -250,16 +245,9 @@ combine_static_libraries() {
"${base_dir}/${build_dir}/ggml/src/ggml-metal/${release_dir}/libggml-metal.a" "${base_dir}/${build_dir}/ggml/src/ggml-metal/${release_dir}/libggml-metal.a"
"${base_dir}/${build_dir}/ggml/src/ggml-blas/${release_dir}/libggml-blas.a" "${base_dir}/${build_dir}/ggml/src/ggml-blas/${release_dir}/libggml-blas.a"
) )
if [[ "$platform" == "macos" || "$platform" == "ios" ]]; then
echo "Adding libwhisper.coreml library to the build."
libs+=(
"${base_dir}/${build_dir}/src/${release_dir}/libwhisper.coreml.a"
)
fi
# Create temporary directory for processing # Create temporary directory for processing
local temp_dir="${base_dir}/${build_dir}/temp" local temp_dir="${base_dir}/${build_dir}/temp"
echo "Creating temporary directory: ${temp_dir}"
mkdir -p "${temp_dir}" mkdir -p "${temp_dir}"
# Since we have multiple architectures libtool will find object files that do not # Since we have multiple architectures libtool will find object files that do not
@ -271,7 +259,6 @@ combine_static_libraries() {
local archs="" local archs=""
local min_version_flag="" local min_version_flag=""
local install_name="" local install_name=""
local frameworks="-framework Foundation -framework Metal -framework Accelerate"
case "$platform" in case "$platform" in
"ios") "ios")
@ -285,14 +272,12 @@ combine_static_libraries() {
min_version_flag="-mios-version-min=${IOS_MIN_OS_VERSION}" min_version_flag="-mios-version-min=${IOS_MIN_OS_VERSION}"
fi fi
install_name="@rpath/whisper.framework/whisper" install_name="@rpath/whisper.framework/whisper"
frameworks+=" -framework CoreML"
;; ;;
"macos") "macos")
sdk="macosx" sdk="macosx"
archs="arm64 x86_64" archs="arm64 x86_64"
min_version_flag="-mmacosx-version-min=${MACOS_MIN_OS_VERSION}" min_version_flag="-mmacosx-version-min=${MACOS_MIN_OS_VERSION}"
install_name="@rpath/whisper.framework/Versions/Current/whisper" install_name="@rpath/whisper.framework/Versions/Current/whisper"
frameworks+=" -framework CoreML"
;; ;;
"visionos") "visionos")
if [[ "$is_simulator" == "true" ]]; then if [[ "$is_simulator" == "true" ]]; then
@ -334,34 +319,27 @@ combine_static_libraries() {
$arch_flags \ $arch_flags \
$min_version_flag \ $min_version_flag \
-Wl,-force_load,"${temp_dir}/combined.a" \ -Wl,-force_load,"${temp_dir}/combined.a" \
$frameworks \ -framework Foundation -framework Metal -framework Accelerate \
-install_name "$install_name" \ -install_name "$install_name" \
-o "${base_dir}/${output_lib}" -o "${base_dir}/${output_lib}"
# Platform-specific post-processing for device builds # Platform-specific post-processing for device builds
if [[ "$is_simulator" == "false" ]]; then if [[ "$is_simulator" == "false" ]]; then
if command -v xcrun vtool &>/dev/null; then if command -v vtool &>/dev/null; then
case "$platform" in case "$platform" in
"ios") "ios")
echo "Marking binary as a framework binary for iOS..." echo "Marking binary as a framework binary for iOS..."
xcrun vtool -set-build-version ios ${IOS_MIN_OS_VERSION} ${IOS_MIN_OS_VERSION} -replace \ vtool -set-build-version ios ${IOS_MIN_OS_VERSION} ${IOS_MIN_OS_VERSION} -replace \
-output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}" -output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
;; ;;
"visionos") "visionos")
echo "Marking binary as a framework binary for visionOS..." echo "Marking binary as a framework binary for visionOS..."
if [[ "$MAJOR_VERSION" -gt 16 ]] || [[ "$MAJOR_VERSION" -eq 16 && "$MINOR_VERSION" -gt 2 ]]; then vtool -set-build-version xros ${VISIONOS_MIN_OS_VERSION} ${VISIONOS_MIN_OS_VERSION} -replace \
echo "Xcode version greater than 16.2, using visionOS."
VISION_OS_BUILD_VERSION="visionos"
else
echo "Xcode version less than or equal to 16.2, using xros."
VISION_OS_BUILD_VERSION="xros"
fi
xcrun vtool -set-build-version ${VISION_OS_BUILD_VERSION} ${VISIONOS_MIN_OS_VERSION} ${VISIONOS_MIN_OS_VERSION} -replace \
-output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}" -output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
;; ;;
"tvos") "tvos")
echo "Marking binary as a framework binary for tvOS..." echo "Marking binary as a framework binary for tvOS..."
xcrun vtool -set-build-version tvos ${TVOS_MIN_OS_VERSION} ${TVOS_MIN_OS_VERSION} -replace \ vtool -set-build-version tvos ${TVOS_MIN_OS_VERSION} ${TVOS_MIN_OS_VERSION} -replace \
-output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}" -output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
;; ;;
esac esac
@ -421,8 +399,6 @@ cmake -B build-ios-sim -G Xcode \
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=iphonesimulator \ -DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=iphonesimulator \
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \ -DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \ -DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
-DWHISPER_COREML="ON" \
-DWHISPER_COREML_ALLOW_FALLBACK="ON" \
-S . -S .
cmake --build build-ios-sim --config Release -- -quiet cmake --build build-ios-sim --config Release -- -quiet
@ -435,8 +411,6 @@ cmake -B build-ios-device -G Xcode \
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=iphoneos \ -DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=iphoneos \
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \ -DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \ -DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
-DWHISPER_COREML="ON" \
-DWHISPER_COREML_ALLOW_FALLBACK="ON" \
-S . -S .
cmake --build build-ios-device --config Release -- -quiet cmake --build build-ios-device --config Release -- -quiet
@ -447,8 +421,6 @@ cmake -B build-macos -G Xcode \
-DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" \ -DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" \
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \ -DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \ -DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
-DWHISPER_COREML="ON" \
-DWHISPER_COREML_ALLOW_FALLBACK="ON" \
-S . -S .
cmake --build build-macos --config Release -- -quiet cmake --build build-macos --config Release -- -quiet
@ -460,8 +432,8 @@ cmake -B build-visionos -G Xcode \
-DCMAKE_SYSTEM_NAME=visionOS \ -DCMAKE_SYSTEM_NAME=visionOS \
-DCMAKE_OSX_SYSROOT=xros \ -DCMAKE_OSX_SYSROOT=xros \
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xros \ -DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xros \
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \ -DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \ -DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_CXX_FLAGS}" \
-S . -S .
cmake --build build-visionos --config Release -- -quiet cmake --build build-visionos --config Release -- -quiet
@ -473,8 +445,8 @@ cmake -B build-visionos-sim -G Xcode \
-DCMAKE_SYSTEM_NAME=visionOS \ -DCMAKE_SYSTEM_NAME=visionOS \
-DCMAKE_OSX_SYSROOT=xrsimulator \ -DCMAKE_OSX_SYSROOT=xrsimulator \
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xrsimulator \ -DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xrsimulator \
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \ -DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \ -DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_CXX_FLAGS}" \
-S . -S .
cmake --build build-visionos-sim --config Release -- -quiet cmake --build build-visionos-sim --config Release -- -quiet

View File

@ -10,8 +10,6 @@
# # with CUDA support # # with CUDA support
# GG_BUILD_CUDA=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt # GG_BUILD_CUDA=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
# #
# # with SYCL support
# GG_BUILD_SYCL=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
if [ -z "$2" ]; then if [ -z "$2" ]; then
echo "usage: $0 <output-dir> <mnt-dir>" echo "usage: $0 <output-dir> <mnt-dir>"
@ -326,9 +324,8 @@ ret=0
for model in "${MODELS[@]}"; do for model in "${MODELS[@]}"; do
test $ret -eq 0 && gg_download_model ${model} test $ret -eq 0 && gg_download_model ${model}
done done
if [ -z ${GG_BUILD_SYCL}]; then
test $ret -eq 0 && gg_run ctest debug test $ret -eq 0 && gg_run ctest debug
fi
test $ret -eq 0 && gg_run ctest release test $ret -eq 0 && gg_run ctest release
test $ret -eq 0 && gg_run bench test $ret -eq 0 && gg_run bench

View File

@ -18,13 +18,6 @@ const whisperParamsMock = {
translate: true, translate: true,
no_timestamps: false, no_timestamps: false,
audio_ctx: 0, audio_ctx: 0,
max_len: 0,
prompt: "",
print_progress: false,
progress_callback: (progress) => {
console.log(`Progress: ${progress}`);
},
max_context: -1
}; };
describe("Run whisper.node", () => { describe("Run whisper.node", () => {

View File

@ -128,67 +128,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
void cb_log_disable(enum ggml_log_level, const char *, void *) {} void cb_log_disable(enum ggml_log_level, const char *, void *) {}
class ProgressWorker : public Napi::AsyncWorker { int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
public:
ProgressWorker(Napi::Function& callback, whisper_params params, Napi::Function progress_callback, Napi::Env env)
: Napi::AsyncWorker(callback), params(params), env(env) {
// Create thread-safe function
if (!progress_callback.IsEmpty()) {
tsfn = Napi::ThreadSafeFunction::New(
env,
progress_callback,
"Progress Callback",
0,
1
);
}
}
~ProgressWorker() {
if (tsfn) {
// Make sure to release the thread-safe function on destruction
tsfn.Release();
}
}
void Execute() override {
// Use custom run function with progress callback support
run_with_progress(params, result);
}
void OnOK() override {
Napi::HandleScope scope(Env());
Napi::Object res = Napi::Array::New(Env(), result.size());
for (uint64_t i = 0; i < result.size(); ++i) {
Napi::Object tmp = Napi::Array::New(Env(), 3);
for (uint64_t j = 0; j < 3; ++j) {
tmp[j] = Napi::String::New(Env(), result[i][j]);
}
res[i] = tmp;
}
Callback().Call({Env().Null(), res});
}
// Progress callback function - using thread-safe function
void OnProgress(int progress) {
if (tsfn) {
// Use thread-safe function to call JavaScript callback
auto callback = [progress](Napi::Env env, Napi::Function jsCallback) {
jsCallback.Call({Napi::Number::New(env, progress)});
};
tsfn.BlockingCall(callback);
}
}
private:
whisper_params params;
std::vector<std::vector<std::string>> result;
Napi::Env env;
Napi::ThreadSafeFunction tsfn;
// Custom run function with progress callback support
int run_with_progress(whisper_params &params, std::vector<std::vector<std::string>> &result) {
if (params.no_prints) { if (params.no_prints) {
whisper_log_set(cb_log_disable, NULL); whisper_log_set(cb_log_disable, NULL);
} }
@ -204,6 +144,7 @@ class ProgressWorker : public Napi::AsyncWorker {
} }
// whisper init // whisper init
struct whisper_context_params cparams = whisper_context_default_params(); struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu; cparams.use_gpu = params.use_gpu;
cparams.flash_attn = params.flash_attn; cparams.flash_attn = params.flash_attn;
@ -214,7 +155,8 @@ class ProgressWorker : public Napi::AsyncWorker {
return 3; return 3;
} }
// If params.pcmf32 provides, set params.fname_inp as "buffer" // if params.pcmf32 is provided, set params.fname_inp to "buffer"
// this is simpler than further modifications in the code
if (!params.pcmf32.empty()) { if (!params.pcmf32.empty()) {
fprintf(stderr, "info: using audio buffer as input\n"); fprintf(stderr, "info: using audio buffer as input\n");
params.fname_inp.clear(); params.fname_inp.clear();
@ -228,7 +170,7 @@ class ProgressWorker : public Napi::AsyncWorker {
std::vector<float> pcmf32; // mono-channel F32 PCM std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
// If params.pcmf32 is empty, read input audio file // read the input audio file if params.pcmf32 is not provided
if (params.pcmf32.empty()) { if (params.pcmf32.empty()) {
if (!::read_audio_data(fname_inp, pcmf32, pcmf32s, params.diarize)) { if (!::read_audio_data(fname_inp, pcmf32, pcmf32s, params.diarize)) {
fprintf(stderr, "error: failed to read audio file '%s'\n", fname_inp.c_str()); fprintf(stderr, "error: failed to read audio file '%s'\n", fname_inp.c_str());
@ -238,14 +180,14 @@ class ProgressWorker : public Napi::AsyncWorker {
pcmf32 = params.pcmf32; pcmf32 = params.pcmf32;
} }
// Print system info // print system information
if (!params.no_prints) { if (!params.no_prints) {
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info()); params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info());
} }
// Print processing info // print some info about the processing
if (!params.no_prints) { if (!params.no_prints) {
fprintf(stderr, "\n"); fprintf(stderr, "\n");
if (!whisper_is_multilingual(ctx)) { if (!whisper_is_multilingual(ctx)) {
@ -266,7 +208,7 @@ class ProgressWorker : public Napi::AsyncWorker {
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
// Run inference // run the inference
{ {
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
@ -299,22 +241,17 @@ class ProgressWorker : public Napi::AsyncWorker {
whisper_print_user_data user_data = { &params, &pcmf32s }; whisper_print_user_data user_data = { &params, &pcmf32s };
// This callback is called for each new segment // this callback is called on each new segment
if (!wparams.print_realtime) { if (!wparams.print_realtime) {
wparams.new_segment_callback = whisper_print_segment_callback; wparams.new_segment_callback = whisper_print_segment_callback;
wparams.new_segment_callback_user_data = &user_data; wparams.new_segment_callback_user_data = &user_data;
} }
// Set progress callback // example for abort mechanism
wparams.progress_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void * user_data) { // in this example, we do not abort the processing, but we could if the flag is set to true
ProgressWorker* worker = static_cast<ProgressWorker*>(user_data); // the callback is called before every encoder run - if it returns false, the processing is aborted
worker->OnProgress(progress);
};
wparams.progress_callback_user_data = this;
// Abort mechanism example
{ {
static bool is_aborted = false; // Note: this should be atomic to avoid data races static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
bool is_aborted = *(bool*)user_data; bool is_aborted = *(bool*)user_data;
@ -346,9 +283,37 @@ class ProgressWorker : public Napi::AsyncWorker {
whisper_free(ctx); whisper_free(ctx);
return 0; return 0;
}
class Worker : public Napi::AsyncWorker {
public:
Worker(Napi::Function& callback, whisper_params params)
: Napi::AsyncWorker(callback), params(params) {}
void Execute() override {
run(params, result);
} }
void OnOK() override {
Napi::HandleScope scope(Env());
Napi::Object res = Napi::Array::New(Env(), result.size());
for (uint64_t i = 0; i < result.size(); ++i) {
Napi::Object tmp = Napi::Array::New(Env(), 3);
for (uint64_t j = 0; j < 3; ++j) {
tmp[j] = Napi::String::New(Env(), result[i][j]);
}
res[i] = tmp;
}
Callback().Call({Env().Null(), res});
}
private:
whisper_params params;
std::vector<std::vector<std::string>> result;
}; };
Napi::Value whisper(const Napi::CallbackInfo& info) { Napi::Value whisper(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env(); Napi::Env env = info.Env();
if (info.Length() <= 0 || !info[0].IsObject()) { if (info.Length() <= 0 || !info[0].IsObject()) {
@ -368,29 +333,6 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
bool comma_in_time = whisper_params.Get("comma_in_time").As<Napi::Boolean>(); bool comma_in_time = whisper_params.Get("comma_in_time").As<Napi::Boolean>();
int32_t max_len = whisper_params.Get("max_len").As<Napi::Number>(); int32_t max_len = whisper_params.Get("max_len").As<Napi::Number>();
// Add support for max_context
int32_t max_context = -1;
if (whisper_params.Has("max_context") && whisper_params.Get("max_context").IsNumber()) {
max_context = whisper_params.Get("max_context").As<Napi::Number>();
}
// support prompt
std::string prompt = "";
if (whisper_params.Has("prompt") && whisper_params.Get("prompt").IsString()) {
prompt = whisper_params.Get("prompt").As<Napi::String>();
}
// Add support for print_progress
bool print_progress = false;
if (whisper_params.Has("print_progress")) {
print_progress = whisper_params.Get("print_progress").As<Napi::Boolean>();
}
// Add support for progress_callback
Napi::Function progress_callback;
if (whisper_params.Has("progress_callback") && whisper_params.Get("progress_callback").IsFunction()) {
progress_callback = whisper_params.Get("progress_callback").As<Napi::Function>();
}
Napi::Value pcmf32Value = whisper_params.Get("pcmf32"); Napi::Value pcmf32Value = whisper_params.Get("pcmf32");
std::vector<float> pcmf32_vec; std::vector<float> pcmf32_vec;
if (pcmf32Value.IsTypedArray()) { if (pcmf32Value.IsTypedArray()) {
@ -413,13 +355,9 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
params.pcmf32 = pcmf32_vec; params.pcmf32 = pcmf32_vec;
params.comma_in_time = comma_in_time; params.comma_in_time = comma_in_time;
params.max_len = max_len; params.max_len = max_len;
params.max_context = max_context;
params.print_progress = print_progress;
params.prompt = prompt;
Napi::Function callback = info[1].As<Napi::Function>(); Napi::Function callback = info[1].As<Napi::Function>();
// Create a new Worker class with progress callback support Worker* worker = new Worker(callback, params);
ProgressWorker* worker = new ProgressWorker(callback, params, progress_callback, env);
worker->Queue(); worker->Queue();
return env.Undefined(); return env.Undefined();
} }

View File

@ -19,9 +19,6 @@ const whisperParams = {
no_timestamps: false, no_timestamps: false,
audio_ctx: 0, audio_ctx: 0,
max_len: 0, max_len: 0,
progress_callback: (progress) => {
console.log(`progress: ${progress}%`);
}
}; };
const arguments = process.argv.slice(2); const arguments = process.argv.slice(2);

View File

@ -2,7 +2,7 @@
Benchmark the performance of whisper.cpp in the browser using WebAssembly Benchmark the performance of whisper.cpp in the browser using WebAssembly
Link: https://ggerganov.github.io/whisper.cpp/bench.wasm Link: https://whisper.ggerganov.com/bench/
Terminal version: [examples/bench](/examples/bench) Terminal version: [examples/bench](/examples/bench)
@ -15,17 +15,7 @@ cd whisper.cpp
mkdir build-em && cd build-em mkdir build-em && cd build-em
emcmake cmake .. emcmake cmake ..
make -j make -j
```
The example can then be started by running a local HTTP server:
```console
python3 examples/server.py
```
And then opening a browser to the following URL:
http://localhost:8000/bench.wasm
To run the example in a different server, you need to copy the following files
to the server's HTTP path:
```
# copy the produced page to your HTTP path # copy the produced page to your HTTP path
cp bin/bench.wasm/* /path/to/html/ cp bin/bench.wasm/* /path/to/html/
cp bin/libbench.worker.js /path/to/html/ cp bin/libbench.worker.js /path/to/html/

View File

@ -24,8 +24,6 @@
overflow-x: scroll; overflow-x: scroll;
} }
</style> </style>
<script src="../coi-serviceworker.js"></script>
<link rel="icon" href="data:,">
</head> </head>
<body> <body>
<div id="main-container"> <div id="main-container">
@ -38,10 +36,11 @@
<br><br> <br><br>
<b>More examples:</b> <b>More examples:</b>
<a href="../">main</a> | <a href="https://whisper.ggerganov.com/">main</a> |
<a href="../bench.wasm/">bench</a> | <a href="https://whisper.ggerganov.com/bench">bench</a> |
<a href="../stream.wasm">stream</a> | <a href="https://whisper.ggerganov.com/stream">stream</a> |
<a href="../command.wasm/">command</a> | <a href="https://whisper.ggerganov.com/command">command</a> |
<a href="https://whisper.ggerganov.com/talk">talk</a> |
<br><br> <br><br>

View File

@ -4,7 +4,7 @@ A very basic tool for benchmarking the inference performance on your device. The
the transformer on some random audio data and records the execution time. This way we can have an objective comparison the transformer on some random audio data and records the execution time. This way we can have an objective comparison
of the performance of the model for various setups. of the performance of the model for various setups.
Benchmark results are tracked in the following Github issue: https://github.com/ggml-org/whisper.cpp/issues/89 Benchmark results are tracked in the following Github issue: https://github.com/ggerganov/whisper.cpp/issues/89
```bash ```bash
# run the bench too on the small.en model using 4 threads # run the bench too on the small.en model using 4 threads
@ -40,7 +40,7 @@ system_info: n_threads = 4 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WA
If you wish, you can submit these results here: If you wish, you can submit these results here:
https://github.com/ggml-org/whisper.cpp/issues/89 https://github.com/ggerganov/whisper.cpp/issues/89
Please include the following information: Please include the following information:

View File

@ -6,8 +6,7 @@ It can be used as a reference for using the `whisper.cpp` library in other proje
``` ```
./build/bin/whisper-cli -h ./build/bin/whisper-cli -h
usage: ./build/bin/whisper-cli [options] file0 file1 ... usage: ./build-pkg/bin/whisper-cli [options] file0.wav file1.wav ...
supported audio formats: flac, mp3, ogg, wav
options: options:
-h, --help [default] show this help message and exit -h, --help [default] show this help message and exit
@ -25,7 +24,6 @@ options:
-wt N, --word-thold N [0.01 ] word timestamp probability threshold -wt N, --word-thold N [0.01 ] word timestamp probability threshold
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail -et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail -lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
-nth N, --no-speech-thold N [0.60 ] no speech threshold
-tp, --temperature N [0.00 ] The sampling temperature, between 0 and 1 -tp, --temperature N [0.00 ] The sampling temperature, between 0 and 1
-tpi, --temperature-inc N [0.20 ] The increment of temperature, between 0 and 1 -tpi, --temperature-inc N [0.20 ] The increment of temperature, between 0 and 1
-debug, --debug-mode [false ] enable debug mode (eg. dump log_mel) -debug, --debug-mode [false ] enable debug mode (eg. dump log_mel)
@ -52,13 +50,12 @@ options:
-dl, --detect-language [false ] exit after automatically detecting language -dl, --detect-language [false ] exit after automatically detecting language
--prompt PROMPT [ ] initial prompt (max n_text_ctx/2 tokens) --prompt PROMPT [ ] initial prompt (max n_text_ctx/2 tokens)
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path -m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-f FNAME, --file FNAME [ ] input audio file path -f FNAME, --file FNAME [ ] input WAV file path
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference -oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
-dtw MODEL --dtw MODEL [ ] compute token-level timestamps -dtw MODEL --dtw MODEL [ ] compute token-level timestamps
-ls, --log-score [false ] log best decoder scores of tokens -ls, --log-score [false ] log best decoder scores of tokens
-ng, --no-gpu [false ] disable GPU -ng, --no-gpu [false ] disable GPU
-fa, --flash-attn [false ] flash attention -fa, --flash-attn [false ] flash attention
-sns, --suppress-nst [false ] suppress non-speech tokens
--suppress-regex REGEX [ ] regular expression matching tokens to suppress --suppress-regex REGEX [ ] regular expression matching tokens to suppress
--grammar GRAMMAR [ ] GBNF grammar to guide decoding --grammar GRAMMAR [ ] GBNF grammar to guide decoding
--grammar-rule RULE [ ] top-level GBNF grammar rule name --grammar-rule RULE [ ] top-level GBNF grammar rule name

View File

@ -13,12 +13,14 @@
#include <cstring> #include <cstring>
#if defined(_WIN32) #if defined(_WIN32)
#ifndef NOMINMAX
#define NOMINMAX #define NOMINMAX
#endif
#include <windows.h> #include <windows.h>
#endif #endif
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
// helper function to replace substrings // helper function to replace substrings
static void replace_all(std::string & s, const std::string & search, const std::string & replace) { static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
for (size_t pos = 0; ; pos += replace.length()) { for (size_t pos = 0; ; pos += replace.length()) {
@ -375,7 +377,15 @@ static void whisper_print_segment_callback(struct whisper_context * ctx, struct
} }
} }
static void output_txt(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) { static bool output_txt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) { for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i); const char * text = whisper_full_get_segment_text(ctx, i);
@ -390,9 +400,19 @@ static void output_txt(struct whisper_context * ctx, std::ofstream & fout, const
fout << speaker << text << "\n"; fout << speaker << text << "\n";
} }
return true;
} }
static void output_vtt(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) { static bool output_vtt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
fout << "WEBVTT\n\n"; fout << "WEBVTT\n\n";
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(ctx);
@ -412,9 +432,19 @@ static void output_vtt(struct whisper_context * ctx, std::ofstream & fout, const
fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n"; fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
fout << speaker << text << "\n\n"; fout << speaker << text << "\n\n";
} }
return true;
} }
static void output_srt(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) { static bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) { for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i); const char * text = whisper_full_get_segment_text(ctx, i);
@ -431,6 +461,8 @@ static void output_srt(struct whisper_context * ctx, std::ofstream & fout, const
fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n"; fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n";
fout << speaker << text << "\n\n"; fout << speaker << text << "\n\n";
} }
return true;
} }
static char * escape_double_quotes_and_backslashes(const char * str) { static char * escape_double_quotes_and_backslashes(const char * str) {
@ -496,7 +528,15 @@ static char * escape_double_quotes_in_csv(const char * str) {
return escaped; return escaped;
} }
static void output_csv(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) { static bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(ctx);
fout << "start,end,"; fout << "start,end,";
if (params.diarize && pcmf32s.size() == 2) if (params.diarize && pcmf32s.size() == 2)
@ -519,9 +559,14 @@ static void output_csv(struct whisper_context * ctx, std::ofstream & fout, const
} }
fout << "\"" << text_escaped << "\"\n"; fout << "\"" << text_escaped << "\"\n";
} }
return true;
} }
static void output_score(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & /*params*/, std::vector<std::vector<float>> /*pcmf32s*/) { static bool output_score(struct whisper_context * ctx, const char * fname, const whisper_params & /*params*/, std::vector<std::vector<float>> /*pcmf32s*/) {
std::ofstream fout(fname);
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(ctx);
// fprintf(stderr,"segments: %d\n",n_segments); // fprintf(stderr,"segments: %d\n",n_segments);
for (int i = 0; i < n_segments; ++i) { for (int i = 0; i < n_segments; ++i) {
@ -534,14 +579,16 @@ static void output_score(struct whisper_context * ctx, std::ofstream & fout, con
// fprintf(stderr,"token: %s %f\n",token,probability); // fprintf(stderr,"token: %s %f\n",token,probability);
} }
} }
return true;
} }
static void output_json( static bool output_json(
struct whisper_context * ctx, struct whisper_context * ctx,
std::ofstream & fout, const char * fname,
const whisper_params & params, const whisper_params & params,
std::vector<std::vector<float>> pcmf32s) { std::vector<std::vector<float>> pcmf32s,
const bool full = params.output_jsn_full; bool full) {
std::ofstream fout(fname);
int indent = 0; int indent = 0;
auto doindent = [&]() { auto doindent = [&]() {
@ -621,6 +668,12 @@ static void output_json(
end_obj(end); end_obj(end);
}; };
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
start_obj(nullptr); start_obj(nullptr);
value_s("systeminfo", whisper_print_system_info(), false); value_s("systeminfo", whisper_print_system_info(), false);
start_obj("model"); start_obj("model");
@ -694,12 +747,17 @@ static void output_json(
end_arr(true); end_arr(true);
end_obj(true); end_obj(true);
return true;
} }
// karaoke video generation // karaoke video generation
// outputs a bash script that uses ffmpeg to generate a video with the subtitles // outputs a bash script that uses ffmpeg to generate a video with the subtitles
// TODO: font parameter adjustments // TODO: font parameter adjustments
static bool output_wts(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector<std::vector<float>> pcmf32s, const char * fname_inp, float t_sec, const char * fname_out) { static bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
static const char * font = params.font_path.c_str(); static const char * font = params.font_path.c_str();
std::ifstream fin(font); std::ifstream fin(font);
@ -815,12 +873,20 @@ static bool output_wts(struct whisper_context * ctx, std::ofstream & fout, const
fout.close(); fout.close();
fprintf(stderr, "# %s: run 'source %s' to generate karaoke video\n", __func__, fname_out); fprintf(stderr, "%s: run 'source %s' to generate karaoke video\n", __func__, fname);
return true; return true;
} }
static void output_lrc(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) { static bool output_lrc(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
fout << "[by:whisper.cpp]\n"; fout << "[by:whisper.cpp]\n";
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(ctx);
@ -848,6 +914,8 @@ static void output_lrc(struct whisper_context * ctx, std::ofstream & fout, const
fout << '[' << timestamp_lrc << ']' << speaker << text << "\n"; fout << '[' << timestamp_lrc << ']' << speaker << text << "\n";
} }
return true;
} }
@ -996,55 +1064,8 @@ int main(int argc, char ** argv) {
} }
for (int f = 0; f < (int) params.fname_inp.size(); ++f) { for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto & fname_inp = params.fname_inp[f]; const auto fname_inp = params.fname_inp[f];
struct fout_factory { const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
std::string fname_out;
const size_t basename_length;
const bool is_stdout;
bool used_stdout;
decltype(whisper_print_segment_callback) * const print_segment_callback;
std::ofstream fout;
fout_factory (const std::string & fname_out_, const std::string & fname_inp, whisper_params & params) :
fname_out{!fname_out_.empty() ? fname_out_ : fname_inp},
basename_length{fname_out.size()},
is_stdout{fname_out == "-"},
used_stdout{},
print_segment_callback{is_stdout ? nullptr : whisper_print_segment_callback} {
if (!print_segment_callback) {
params.print_progress = false;
}
}
bool open(const char * ext, const char * function) {
if (is_stdout) {
if (used_stdout) {
fprintf(stderr, "warning: Not appending multiple file formats to stdout\n");
return false;
}
used_stdout = true;
#ifdef _WIN32
fout = std::ofstream{"CON"};
#else
fout = std::ofstream{"/dev/stdout"};
#endif
// Not using fprintf stderr here because it might equal stdout
// Also assuming /dev is mounted
return true;
}
fname_out.resize(basename_length);
fname_out += ext;
fout = std::ofstream{fname_out};
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", function, fname_out.c_str());
return true;
}
} fout_factory{f < (int) params.fname_out.size() ? params.fname_out[f] : "", fname_inp, params};
std::vector<float> pcmf32; // mono-channel F32 PCM std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
@ -1149,7 +1170,7 @@ int main(int argc, char ** argv) {
// this callback is called on each new segment // this callback is called on each new segment
if (!wparams.print_realtime) { if (!wparams.print_realtime) {
wparams.new_segment_callback = fout_factory.print_segment_callback; wparams.new_segment_callback = whisper_print_segment_callback;
wparams.new_segment_callback_user_data = &user_data; wparams.new_segment_callback_user_data = &user_data;
} }
@ -1191,26 +1212,54 @@ int main(int argc, char ** argv) {
// output stuff // output stuff
{ {
// macros to stringify function name printf("\n");
#define output_func(func, ext, param, ...) if (param && fout_factory.open(ext, #func)) {\
func(ctx, fout_factory.fout, params, __VA_ARGS__); \
}
#define output_ext(ext, ...) output_func(output_##ext, "." #ext, params.output_##ext, __VA_ARGS__)
output_ext(txt, pcmf32s); // output to text file
output_ext(vtt, pcmf32s); if (params.output_txt) {
output_ext(srt, pcmf32s); const auto fname_txt = fname_out + ".txt";
output_ext(wts, pcmf32s, fname_inp.c_str(), float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE, fout_factory.fname_out.c_str()); output_txt(ctx, fname_txt.c_str(), params, pcmf32s);
output_ext(csv, pcmf32s); }
output_func(output_json, ".json", params.output_jsn, pcmf32s);
output_ext(lrc, pcmf32s);
output_func(output_score, ".score.txt", params.log_score, pcmf32s);
#undef output_ext // output to VTT file
#undef output_func if (params.output_vtt) {
const auto fname_vtt = fname_out + ".vtt";
output_vtt(ctx, fname_vtt.c_str(), params, pcmf32s);
}
if (fout_factory.is_stdout && !fout_factory.used_stdout) { // output to SRT file
fprintf(stderr, "warning: '--output-file -' used without any other '--output-*'"); if (params.output_srt) {
const auto fname_srt = fname_out + ".srt";
output_srt(ctx, fname_srt.c_str(), params, pcmf32s);
}
// output to WTS file
if (params.output_wts) {
const auto fname_wts = fname_out + ".wts";
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE, pcmf32s);
}
// output to CSV file
if (params.output_csv) {
const auto fname_csv = fname_out + ".csv";
output_csv(ctx, fname_csv.c_str(), params, pcmf32s);
}
// output to JSON file
if (params.output_jsn) {
const auto fname_jsn = fname_out + ".json";
output_json(ctx, fname_jsn.c_str(), params, pcmf32s, params.output_jsn_full);
}
// output to LRC file
if (params.output_lrc) {
const auto fname_lrc = fname_out + ".lrc";
output_lrc(ctx, fname_lrc.c_str(), params, pcmf32s);
}
// output to score file
if (params.log_score) {
const auto fname_score = fname_out + ".score.txt";
output_score(ctx, fname_score.c_str(), params, pcmf32s);
} }
} }
} }

View File

@ -1,146 +0,0 @@
/*! coi-serviceworker v0.1.7 - Guido Zuidhof and contributors, licensed under MIT */
let coepCredentialless = false;
if (typeof window === 'undefined') {
self.addEventListener("install", () => self.skipWaiting());
self.addEventListener("activate", (event) => event.waitUntil(self.clients.claim()));
self.addEventListener("message", (ev) => {
if (!ev.data) {
return;
} else if (ev.data.type === "deregister") {
self.registration
.unregister()
.then(() => {
return self.clients.matchAll();
})
.then(clients => {
clients.forEach((client) => client.navigate(client.url));
});
} else if (ev.data.type === "coepCredentialless") {
coepCredentialless = ev.data.value;
}
});
self.addEventListener("fetch", function (event) {
const r = event.request;
if (r.cache === "only-if-cached" && r.mode !== "same-origin") {
return;
}
const request = (coepCredentialless && r.mode === "no-cors")
? new Request(r, {
credentials: "omit",
})
: r;
event.respondWith(
fetch(request)
.then((response) => {
if (response.status === 0) {
return response;
}
const newHeaders = new Headers(response.headers);
newHeaders.set("Cross-Origin-Embedder-Policy",
coepCredentialless ? "credentialless" : "require-corp"
);
if (!coepCredentialless) {
newHeaders.set("Cross-Origin-Resource-Policy", "cross-origin");
}
newHeaders.set("Cross-Origin-Opener-Policy", "same-origin");
return new Response(response.body, {
status: response.status,
statusText: response.statusText,
headers: newHeaders,
});
})
.catch((e) => console.error(e))
);
});
} else {
(() => {
const reloadedBySelf = window.sessionStorage.getItem("coiReloadedBySelf");
window.sessionStorage.removeItem("coiReloadedBySelf");
const coepDegrading = (reloadedBySelf == "coepdegrade");
// You can customize the behavior of this script through a global `coi` variable.
const coi = {
shouldRegister: () => !reloadedBySelf,
shouldDeregister: () => false,
coepCredentialless: () => true,
coepDegrade: () => true,
doReload: () => window.location.reload(),
quiet: false,
...window.coi
};
const n = navigator;
const controlling = n.serviceWorker && n.serviceWorker.controller;
// Record the failure if the page is served by serviceWorker.
if (controlling && !window.crossOriginIsolated) {
window.sessionStorage.setItem("coiCoepHasFailed", "true");
}
const coepHasFailed = window.sessionStorage.getItem("coiCoepHasFailed");
if (controlling) {
// Reload only on the first failure.
const reloadToDegrade = coi.coepDegrade() && !(
coepDegrading || window.crossOriginIsolated
);
n.serviceWorker.controller.postMessage({
type: "coepCredentialless",
value: (reloadToDegrade || coepHasFailed && coi.coepDegrade())
? false
: coi.coepCredentialless(),
});
if (reloadToDegrade) {
!coi.quiet && console.log("Reloading page to degrade COEP.");
window.sessionStorage.setItem("coiReloadedBySelf", "coepdegrade");
coi.doReload("coepdegrade");
}
if (coi.shouldDeregister()) {
n.serviceWorker.controller.postMessage({ type: "deregister" });
}
}
// If we're already coi: do nothing. Perhaps it's due to this script doing its job, or COOP/COEP are
// already set from the origin server. Also if the browser has no notion of crossOriginIsolated, just give up here.
if (window.crossOriginIsolated !== false || !coi.shouldRegister()) return;
if (!window.isSecureContext) {
!coi.quiet && console.log("COOP/COEP Service Worker not registered, a secure context is required.");
return;
}
// In some environments (e.g. Firefox private mode) this won't be available
if (!n.serviceWorker) {
!coi.quiet && console.error("COOP/COEP Service Worker not registered, perhaps due to private mode.");
return;
}
n.serviceWorker.register(window.document.currentScript.src).then(
(registration) => {
!coi.quiet && console.log("COOP/COEP Service Worker registered", registration.scope);
registration.addEventListener("updatefound", () => {
!coi.quiet && console.log("Reloading page to make use of updated COOP/COEP Service Worker.");
window.sessionStorage.setItem("coiReloadedBySelf", "updatefound");
coi.doReload();
});
// If the registration is active, but it's not controlling the page
if (registration.active && !n.serviceWorker.controller) {
!coi.quiet && console.log("Reloading page to make use of COOP/COEP Service Worker.");
window.sessionStorage.setItem("coiReloadedBySelf", "notcontrolling");
coi.doReload();
}
},
(err) => {
!coi.quiet && console.error("COOP/COEP Service Worker failed to register:", err);
}
);
})();
}

View File

@ -3,7 +3,7 @@
This is a basic Voice Assistant example that accepts voice commands from the microphone. This is a basic Voice Assistant example that accepts voice commands from the microphone.
It runs in fully in the browser via WebAseembly. It runs in fully in the browser via WebAseembly.
Online demo: https://ggerganov.github.io/whisper.cpp/command.wasm Online demo: https://whisper.ggerganov.com/command/
Terminal version: [examples/command](/examples/command) Terminal version: [examples/command](/examples/command)
@ -15,18 +15,9 @@ git clone https://github.com/ggerganov/whisper.cpp
cd whisper.cpp cd whisper.cpp
mkdir build-em && cd build-em mkdir build-em && cd build-em
emcmake cmake .. emcmake cmake ..
make -j libcommand make -j
```
The example can then be started by running a local HTTP server:
```console
python3 examples/server.py
```
And then opening a browser to the following URL:
http://localhost:8000/command.wasm/
To run the example in a different server, you need to copy the following files # copy the produced page to your HTTP path
to the server's HTTP path:
```
cp bin/command.wasm/* /path/to/html/ cp bin/command.wasm/* /path/to/html/
cp bin/libcommand.worker.js /path/to/html/ cp bin/libcommand.worker.js /path/to/html/
``` ```

View File

@ -24,8 +24,6 @@
overflow-x: scroll; overflow-x: scroll;
} }
</style> </style>
<script src="../coi-serviceworker.js"></script>
<link rel="icon" href="data:,">
</head> </head>
<body> <body>
<div id="main-container"> <div id="main-container">
@ -38,10 +36,11 @@
<br><br> <br><br>
<b>More examples:</b> <b>More examples:</b>
<a href="../">main</a> | <a href="https://whisper.ggerganov.com/">main</a> |
<a href="../bench.wasm/">bench</a> | <a href="https://whisper.ggerganov.com/bench">bench</a> |
<a href="../stream.wasm">stream</a> | <a href="https://whisper.ggerganov.com/stream">stream</a> |
<a href="../command.wasm/">command</a> | <a href="https://whisper.ggerganov.com/command">command</a> |
<a href="https://whisper.ggerganov.com/talk">talk</a> |
<br><br> <br><br>

View File

@ -3,7 +3,7 @@
// Speak short text commands to the microphone. // Speak short text commands to the microphone.
// This program will detect your voice command and convert them to text. // This program will detect your voice command and convert them to text.
// //
// ref: https://github.com/ggml-org/whisper.cpp/issues/171 // ref: https://github.com/ggerganov/whisper.cpp/issues/171
// //
#include "common-sdl.h" #include "common-sdl.h"

View File

@ -26,6 +26,10 @@
#define MINIAUDIO_IMPLEMENTATION #define MINIAUDIO_IMPLEMENTATION
#include "miniaudio.h" #include "miniaudio.h"
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
#ifdef _WIN32 #ifdef _WIN32
#include <fcntl.h> #include <fcntl.h>
#include <io.h> #include <io.h>

View File

@ -10,6 +10,10 @@
#include <regex> #include <regex>
#include <sstream> #include <sstream>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
// Function to check if the next argument exists // Function to check if the next argument exists
static std::string get_next_arg(int& i, int argc, char** argv, const std::string& flag, gpt_params& params) { static std::string get_next_arg(int& i, int argc, char** argv, const std::string& flag, gpt_params& params) {
if (i + 1 < argc && argv[i + 1][0] != '-') { if (i + 1 < argc && argv[i + 1][0] != '-') {
@ -243,6 +247,17 @@ std::map<std::string, int32_t> json_parse(const std::string & fname) {
return result; return result;
} }
std::string convert_to_utf8(const std::wstring & input) {
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
return converter.to_bytes(input);
}
std::wstring convert_to_wstring(const std::string & input) {
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
return converter.from_bytes(input);
}
void gpt_split_words(std::string str, std::vector<std::string>& words) { void gpt_split_words(std::string str, std::vector<std::string>& words) {
const std::string pattern = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"; const std::string pattern = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
const std::regex re(pattern); const std::regex re(pattern);

View File

@ -1,6 +1,4 @@
add_executable(main ./deprecation-warning.cpp) add_executable(main ./deprecation-warning.cpp)
add_executable(bench ./deprecation-warning.cpp) add_executable(bench ./deprecation-warning.cpp)
if (WHISPER_SDL2) add_executable(stream ./deprecation-warning.cpp)
add_executable(stream ./deprecation-warning.cpp) add_executable(command ./deprecation-warning.cpp)
add_executable(command ./deprecation-warning.cpp)
endif()

View File

@ -194,7 +194,7 @@ static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size)
AVIOContext *avio_ctx; AVIOContext *avio_ctx;
AVStream *stream; AVStream *stream;
AVCodecContext *codec; AVCodecContext *codec;
AVPacket *packet; AVPacket packet;
AVFrame *frame; AVFrame *frame;
struct SwrContext *swr; struct SwrContext *swr;
u8 *avio_ctx_buffer; u8 *avio_ctx_buffer;
@ -249,20 +249,6 @@ static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size)
/* prepare resampler */ /* prepare resampler */
swr = swr_alloc(); swr = swr_alloc();
#if LIBAVCODEC_VERSION_MAJOR > 60
AVChannelLayout in_ch_layout = codec->ch_layout;
AVChannelLayout out_ch_layout = AV_CHANNEL_LAYOUT_MONO;
/* Set the source audio layout as-is */
av_opt_set_chlayout(swr, "in_chlayout", &in_ch_layout, 0);
av_opt_set_int(swr, "in_sample_rate", codec->sample_rate, 0);
av_opt_set_sample_fmt(swr, "in_sample_fmt", codec->sample_fmt, 0);
/* Convert it into 16khz Mono */
av_opt_set_chlayout(swr, "out_chlayout", &out_ch_layout, 0);
av_opt_set_int(swr, "out_sample_rate", WAVE_SAMPLE_RATE, 0);
av_opt_set_sample_fmt(swr, "out_sample_fmt", AV_SAMPLE_FMT_S16, 0);
#else
av_opt_set_int(swr, "in_channel_count", codec->channels, 0); av_opt_set_int(swr, "in_channel_count", codec->channels, 0);
av_opt_set_int(swr, "out_channel_count", 1, 0); av_opt_set_int(swr, "out_channel_count", 1, 0);
av_opt_set_int(swr, "in_channel_layout", codec->channel_layout, 0); av_opt_set_int(swr, "in_channel_layout", codec->channel_layout, 0);
@ -271,7 +257,6 @@ static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size)
av_opt_set_int(swr, "out_sample_rate", WAVE_SAMPLE_RATE, 0); av_opt_set_int(swr, "out_sample_rate", WAVE_SAMPLE_RATE, 0);
av_opt_set_sample_fmt(swr, "in_sample_fmt", codec->sample_fmt, 0); av_opt_set_sample_fmt(swr, "in_sample_fmt", codec->sample_fmt, 0);
av_opt_set_sample_fmt(swr, "out_sample_fmt", AV_SAMPLE_FMT_S16, 0); av_opt_set_sample_fmt(swr, "out_sample_fmt", AV_SAMPLE_FMT_S16, 0);
#endif
swr_init(swr); swr_init(swr);
if (!swr_is_initialized(swr)) { if (!swr_is_initialized(swr)) {
@ -279,11 +264,7 @@ static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size)
return -1; return -1;
} }
packet=av_packet_alloc(); av_init_packet(&packet);
if (!packet) {
LOG("Error allocating the packet\n");
return -1;
}
frame = av_frame_alloc(); frame = av_frame_alloc();
if (!frame) { if (!frame) {
LOG("Error allocating the frame\n"); LOG("Error allocating the frame\n");
@ -293,8 +274,8 @@ static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size)
/* iterate through frames */ /* iterate through frames */
*data = NULL; *data = NULL;
*size = 0; *size = 0;
while (av_read_frame(fmt_ctx, packet) >= 0) { while (av_read_frame(fmt_ctx, &packet) >= 0) {
avcodec_send_packet(codec, packet); avcodec_send_packet(codec, &packet);
err = avcodec_receive_frame(codec, frame); err = avcodec_receive_frame(codec, frame);
if (err == AVERROR(EAGAIN)) if (err == AVERROR(EAGAIN))
@ -305,11 +286,10 @@ static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size)
/* Flush any remaining conversion buffers... */ /* Flush any remaining conversion buffers... */
convert_frame(swr, codec, frame, data, size, true); convert_frame(swr, codec, frame, data, size, true);
av_packet_free(&packet);
av_frame_free(&frame); av_frame_free(&frame);
swr_free(&swr); swr_free(&swr);
//avio_context_free(); // todo? //avio_context_free(); // todo?
avcodec_free_context(&codec); avcodec_close(codec);
avformat_close_input(&fmt_ctx); avformat_close_input(&fmt_ctx);
avformat_free_context(fmt_ctx); avformat_free_context(fmt_ctx);

View File

@ -2,7 +2,7 @@
# #
# Transcribe audio livestream by feeding ffmpeg output to whisper.cpp at regular intervals # Transcribe audio livestream by feeding ffmpeg output to whisper.cpp at regular intervals
# Idea by @semiformal-net # Idea by @semiformal-net
# ref: https://github.com/ggml-org/whisper.cpp/issues/185 # ref: https://github.com/ggerganov/whisper.cpp/issues/185
# #
set -eo pipefail set -eo pipefail

View File

@ -1,115 +0,0 @@
import http.server
import socketserver
import os
import sys
from pathlib import Path
import urllib.parse
SCRIPT_DIR = Path(__file__).parent.absolute()
DIRECTORY = os.path.join(SCRIPT_DIR, "../build-em/bin")
DIRECTORY = os.path.abspath(DIRECTORY)
# The context root we want for all applications
CONTEXT_ROOT = "/whisper.cpp"
class CustomHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, directory=DIRECTORY, **kwargs)
def do_GET(self):
# Redirect root to the context root
if self.path == '/':
self.send_response(302)
self.send_header('Location', CONTEXT_ROOT + '/')
self.end_headers()
return
# Handle requests under the context root
if self.path.startswith(CONTEXT_ROOT):
# Remove the context root prefix to get the actual path
actual_path = self.path[len(CONTEXT_ROOT):]
if not actual_path:
self.send_response(302)
self.send_header('Location', CONTEXT_ROOT + '/')
self.end_headers()
return
if '.worker.js' in actual_path:
worker_file = os.path.basename(actual_path)
worker_path = os.path.join(DIRECTORY, worker_file)
if os.path.exists(worker_path):
print(f"Found worker file: {worker_path}")
self.path = '/' + worker_file
else:
print(f"Worker file not found: {worker_path}")
elif actual_path == '/':
self.path = '/whisper.wasm/index.html'
elif actual_path.startswith('/bench.wasm/') or actual_path.startswith('/command.wasm/') or actual_path.startswith('/stream.wasm/'):
# Keep the path as is, just remove the context root
self.path = actual_path
# For all other paths under the context root
else:
# Check if this is a request to a file in whisper.wasm
potential_file = os.path.join(DIRECTORY, 'whisper.wasm', actual_path.lstrip('/'))
if os.path.exists(potential_file) and not os.path.isdir(potential_file):
self.path = '/whisper.wasm' + actual_path
else:
# Try to resolve the file from the base directory
potential_file = os.path.join(DIRECTORY, actual_path.lstrip('/'))
if os.path.exists(potential_file):
self.path = actual_path
# For direct requests to worker files (without context root as these
# are in the build-em/bin directory
elif '.worker.js' in self.path:
worker_file = os.path.basename(self.path)
worker_path = os.path.join(DIRECTORY, worker_file)
if os.path.exists(worker_path):
self.path = '/' + worker_file
# Handle coi-serviceworker.js separately
if 'coi-serviceworker.js' in self.path:
worker_file = "coi-serviceworker.js"
worker_path = os.path.join(SCRIPT_DIR, worker_file)
if os.path.exists(worker_path):
self.send_response(200)
self.send_header('Content-type', 'application/javascript')
self.end_headers()
with open(worker_path, 'rb') as file:
self.wfile.write(file.read())
return
else:
print(f"Warning: Could not find {worker_path}")
return super().do_GET()
def end_headers(self):
# Add required headers for SharedArrayBuffer
self.send_header("Cross-Origin-Opener-Policy", "same-origin")
self.send_header("Cross-Origin-Embedder-Policy", "require-corp")
self.send_header("Access-Control-Allow-Origin", "*")
super().end_headers()
PORT = 8000
# Enable address reuse
class CustomServer(socketserver.TCPServer):
allow_reuse_address = True
try:
with CustomServer(("", PORT), CustomHTTPRequestHandler) as httpd:
print(f"Serving directory '{DIRECTORY}' at http://localhost:{PORT}")
print(f"Application context root: http://localhost:{PORT}{CONTEXT_ROOT}/")
try:
httpd.serve_forever()
except KeyboardInterrupt:
print("\nServer stopped.")
# Force complete exit
sys.exit(0)
except OSError as e:
print(f"Error: {e}")
sys.exit(1)

File diff suppressed because it is too large Load Diff

View File

@ -14,6 +14,10 @@
#include <thread> #include <thread>
#include <vector> #include <vector>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
using namespace httplib; using namespace httplib;
using json = nlohmann::ordered_json; using json = nlohmann::ordered_json;
@ -75,7 +79,6 @@ struct whisper_params {
bool use_gpu = true; bool use_gpu = true;
bool flash_attn = false; bool flash_attn = false;
bool suppress_nst = false; bool suppress_nst = false;
bool no_context = false;
std::string language = "en"; std::string language = "en";
std::string prompt = ""; std::string prompt = "";
@ -137,8 +140,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server\n", sparams.ffmpeg_converter ? "true" : "false"); fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server\n", sparams.ffmpeg_converter ? "true" : "false");
fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false"); fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false");
fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold); fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold);
fprintf(stderr, " -nc, --no-context [%-7s] do not use previous audio context\n", params.no_context ? "true" : "false");
fprintf(stderr, " -ng, --no-gpu [%-7s] do not use gpu\n", params.use_gpu ? "false" : "true");
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
@ -185,7 +186,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; }
else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(argv[++i]); } else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(argv[++i]); }
else if (arg == "-nc" || arg == "--no-context") { params.no_context = true; }
// server params // server params
else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); } else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
@ -506,10 +506,6 @@ void get_req_parameters(const Request & req, whisper_params & params)
{ {
params.suppress_nst = parse_str_to_bool(req.get_file_value("suppress_nst").content); params.suppress_nst = parse_str_to_bool(req.get_file_value("suppress_nst").content);
} }
if (req.has_file("no_context"))
{
params.no_context = parse_str_to_bool(req.get_file_value("no_context").content);
}
} }
} // namespace } // namespace
@ -822,7 +818,6 @@ int main(int argc, char ** argv) {
wparams.no_timestamps = params.no_timestamps; wparams.no_timestamps = params.no_timestamps;
wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format; wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format;
wparams.no_context = params.no_context;
wparams.suppress_nst = params.suppress_nst; wparams.suppress_nst = params.suppress_nst;
@ -839,25 +834,33 @@ int main(int argc, char ** argv) {
wparams.progress_callback_user_data = &user_data; wparams.progress_callback_user_data = &user_data;
} }
// tell whisper to abort if the HTTP connection closed // examples for abort mechanism
wparams.abort_callback = [](void *user_data) { // in examples below, we do not abort the processing, but we could if the flag is set to true
// user_data is a pointer to our Request
auto req_ptr = static_cast<const httplib::Request*>(user_data); // the callback is called before every encoder run - if it returns false, the processing is aborted
return req_ptr->is_connection_closed(); {
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
}; };
wparams.abort_callback_user_data = (void*)&req; wparams.encoder_begin_callback_user_data = &is_aborted;
}
// the callback is called before every computation - if it returns true, the computation is aborted
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.abort_callback = [](void * user_data) {
bool is_aborted = *(bool*)user_data;
return is_aborted;
};
wparams.abort_callback_user_data = &is_aborted;
}
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) { if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
// handle failure or early abort
if (req.is_connection_closed()) {
// log client disconnect
fprintf(stderr, "client disconnected, aborted processing\n");
res.status = 499; // Client Closed Request (nginx convention)
res.set_content("{\"error\":\"client disconnected\"}", "application/json");
return;
}
fprintf(stderr, "%s: failed to process audio\n", argv[0]); fprintf(stderr, "%s: failed to process audio\n", argv[0]);
res.status = 500; // Internal Server Error
const std::string error_resp = "{\"error\":\"failed to process audio\"}"; const std::string error_resp = "{\"error\":\"failed to process audio\"}";
res.set_content(error_resp, "application/json"); res.set_content(error_resp, "application/json");
return; return;
@ -916,25 +919,13 @@ int main(int argc, char ** argv) {
} else if (params.response_format == vjson_format) { } else if (params.response_format == vjson_format) {
/* try to match openai/whisper's Python format */ /* try to match openai/whisper's Python format */
std::string results = output_str(ctx, params, pcmf32s); std::string results = output_str(ctx, params, pcmf32s);
// Get language probabilities
std::vector<float> lang_probs(whisper_lang_max_id() + 1, 0.0f);
const auto detected_lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, lang_probs.data());
json jres = json{ json jres = json{
{"task", params.translate ? "translate" : "transcribe"}, {"task", params.translate ? "translate" : "transcribe"},
{"language", whisper_lang_str_full(whisper_full_lang_id(ctx))}, {"language", whisper_lang_str_full(whisper_full_lang_id(ctx))},
{"duration", float(pcmf32.size())/WHISPER_SAMPLE_RATE}, {"duration", float(pcmf32.size())/WHISPER_SAMPLE_RATE},
{"text", results}, {"text", results},
{"segments", json::array()}, {"segments", json::array()}
{"detected_language", whisper_lang_str_full(detected_lang_id)},
{"detected_language_probability", lang_probs[detected_lang_id]},
{"language_probabilities", json::object()}
}; };
// Add all language probabilities
for (int i = 0; i <= whisper_lang_max_id(); ++i) {
if (lang_probs[i] > 0.001f) { // Only include non-negligible probabilities
jres["language_probabilities"][whisper_lang_str(i)] = lang_probs[i];
}
}
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) for (int i = 0; i < n_segments; ++i)
{ {
@ -1033,11 +1024,6 @@ int main(int argc, char ** argv) {
// check if the model is in the file system // check if the model is in the file system
}); });
svr.Get(sparams.request_path + "/health", [&](const Request &, Response &res){
const std::string health_response = "{\"status\":\"ok\"}";
res.set_content(health_response, "application/json");
});
svr.set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) { svr.set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) {
const char fmt[] = "500 Internal Server Error\n%s"; const char fmt[] = "500 Internal Server Error\n%s";
char buf[BUFSIZ]; char buf[BUFSIZ];

View File

@ -13,17 +13,7 @@ cd whisper.cpp
mkdir build-em && cd build-em mkdir build-em && cd build-em
emcmake cmake .. emcmake cmake ..
make -j make -j
```
The example can then be started by running a local HTTP server:
```console
python3 examples/server.py
```
And then opening a browser to the following URL:
http://localhost:8000/stream.wasm
To run the example in a different server, you need to copy the following files
to the server's HTTP path:
```
# copy the produced page to your HTTP path # copy the produced page to your HTTP path
cp bin/stream.wasm/* /path/to/html/ cp bin/stream.wasm/* /path/to/html/
cp bin/libstream.worker.js /path/to/html/ cp bin/libstream.worker.js /path/to/html/

View File

@ -24,8 +24,6 @@
overflow-x: scroll; overflow-x: scroll;
} }
</style> </style>
<script src="../coi-serviceworker.js"></script>
<link rel="icon" href="data:,">
</head> </head>
<body> <body>
<div id="main-container"> <div id="main-container">
@ -38,10 +36,11 @@
<br><br> <br><br>
<b>More examples:</b> <b>More examples:</b>
<a href="../">main</a> | <a href="https://whisper.ggerganov.com/">main</a> |
<a href="../bench.wasm/">bench</a> | <a href="https://whisper.ggerganov.com/bench">bench</a> |
<a href="../stream.wasm">stream</a> | <a href="https://whisper.ggerganov.com/stream">stream</a> |
<a href="../command.wasm/">command</a> | <a href="https://whisper.ggerganov.com/command">command</a> |
<a href="https://whisper.ggerganov.com/talk">talk</a> |
<br><br> <br><br>

View File

@ -12,12 +12,9 @@ if (WHISPER_SDL2)
llama-context.cpp llama-context.cpp
llama-cparams.cpp llama-cparams.cpp
llama-grammar.cpp llama-grammar.cpp
llama-graph.cpp
llama-hparams.cpp llama-hparams.cpp
llama-impl.cpp llama-impl.cpp
llama-io.cpp
llama-kv-cache.cpp llama-kv-cache.cpp
llama-memory.cpp
llama-mmap.cpp llama-mmap.cpp
llama-model-loader.cpp llama-model-loader.cpp
llama-model.cpp llama-model.cpp

View File

@ -4,13 +4,14 @@
#include "llama-mmap.h" #include "llama-mmap.h"
#include "llama-model.h" #include "llama-model.h"
#include <algorithm>
#include <map> #include <map>
#include <cassert> #include <cassert>
#include <stdexcept> #include <stdexcept>
// vec // vec
ggml_tensor * llama_adapter_cvec::tensor_for(int il) const { struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) { if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) {
return nullptr; return nullptr;
} }
@ -18,7 +19,7 @@ ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
return tensors[il]; return tensors[il];
} }
ggml_tensor * llama_adapter_cvec::apply_to(ggml_context * ctx, ggml_tensor * cur, int il) const { struct ggml_tensor * llama_adapter_cvec::apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const {
ggml_tensor * layer_dir = tensor_for(il); ggml_tensor * layer_dir = tensor_for(il);
if (layer_dir != nullptr) { if (layer_dir != nullptr) {
cur = ggml_add(ctx, cur, layer_dir); cur = ggml_add(ctx, cur, layer_dir);
@ -39,7 +40,7 @@ bool llama_adapter_cvec::init(const llama_model & model) {
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
auto it = ctx_map.find(buft); auto it = ctx_map.find(buft);
if (it == ctx_map.end()) { if (it == ctx_map.end()) {
ggml_init_params params = { struct ggml_init_params params = {
/*.mem_size =*/ hparams.n_layer*ggml_tensor_overhead(), /*.mem_size =*/ hparams.n_layer*ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL, /*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true, /*.no_alloc =*/ true,
@ -90,7 +91,7 @@ bool llama_adapter_cvec::init(const llama_model & model) {
return true; return true;
} }
bool llama_adapter_cvec::apply( int32_t llama_adapter_cvec::apply(
const llama_model & model, const llama_model & model,
const float * data, const float * data,
size_t len, size_t len,
@ -103,17 +104,17 @@ bool llama_adapter_cvec::apply(
// disable the current control vector (but leave allocated for later) // disable the current control vector (but leave allocated for later)
layer_start = -1; layer_start = -1;
layer_end = -1; layer_end = -1;
return true; return 0;
} }
if (n_embd != (int) hparams.n_embd) { if (n_embd != (int) hparams.n_embd) {
LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__); LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__);
return false; return 1;
} }
if (tensors.empty()) { if (tensors.empty()) {
if (!init(model)) { if (!init(model)) {
return false; return 1;
} }
} }
@ -129,12 +130,12 @@ bool llama_adapter_cvec::apply(
} }
} }
return true; return 0;
} }
// lora // lora
llama_adapter_lora_weight * llama_adapter_lora::get_weight(ggml_tensor * w) { llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor * w) {
const std::string name(w->name); const std::string name(w->name);
const auto pos = ab_map.find(name); const auto pos = ab_map.find(name);
@ -145,11 +146,11 @@ llama_adapter_lora_weight * llama_adapter_lora::get_weight(ggml_tensor * w) {
return nullptr; return nullptr;
} }
static void llama_adapter_lora_init_impl(llama_model & model, const char * path_lora, llama_adapter_lora & adapter) { static void llama_adapter_lora_init_impl(struct llama_model & model, const char * path_lora, struct llama_adapter_lora & adapter) {
LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora); LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);
ggml_context * ctx_init; ggml_context * ctx_init;
gguf_init_params meta_gguf_params = { struct gguf_init_params meta_gguf_params = {
/* .no_alloc = */ true, /* .no_alloc = */ true,
/* .ctx = */ &ctx_init, /* .ctx = */ &ctx_init,
}; };
@ -200,7 +201,7 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
auto it = ctx_map.find(buft); auto it = ctx_map.find(buft);
if (it == ctx_map.end()) { if (it == ctx_map.end()) {
// add a new context // add a new context
ggml_init_params params = { struct ggml_init_params params = {
/*.mem_size =*/ n_tensors*ggml_tensor_overhead(), /*.mem_size =*/ n_tensors*ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL, /*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true, /*.no_alloc =*/ true,
@ -247,26 +248,6 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
} }
} }
// get extra buffer types of the CPU
// TODO: a more general solution for non-CPU extra buft should be imlpemented in the future
// ref: https://github.com/ggml-org/llama.cpp/pull/12593#pullrequestreview-2718659948
std::vector<ggml_backend_buffer_type_t> buft_extra;
{
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
if (ggml_backend_dev_get_extra_bufts_fn) {
ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev);
while (extra_bufts && *extra_bufts) {
buft_extra.emplace_back(*extra_bufts);
++extra_bufts;
}
}
}
// add tensors // add tensors
for (auto & it : ab_map) { for (auto & it : ab_map) {
const std::string & name = it.first; const std::string & name = it.first;
@ -283,23 +264,7 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)"); throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)");
} }
auto * buft = ggml_backend_buffer_get_type(model_tensor->buffer); struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
// do not load loras to extra buffer types (i.e. bufts for repacking) -> use the CPU in that case
for (auto & ex : buft_extra) {
if (ex == buft) {
LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
buft = ggml_backend_dev_buffer_type(cpu_dev);
break;
}
}
LLAMA_LOG_DEBUG("%s: lora for '%s' -> '%s'\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
ggml_context * dev_ctx = ctx_for_buft(buft);
// validate tensor shape // validate tensor shape
if (is_token_embd) { if (is_token_embd) {
// expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd() // expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd()
@ -316,8 +281,8 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
} }
// save tensor to adapter // save tensor to adapter
ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a); struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b); struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
ggml_set_name(tensor_a, w.a->name); ggml_set_name(tensor_a, w.a->name);
ggml_set_name(tensor_b, w.b->name); ggml_set_name(tensor_b, w.b->name);
adapter.ab_map[name] = llama_adapter_lora_weight(tensor_a, tensor_b); adapter.ab_map[name] = llama_adapter_lora_weight(tensor_a, tensor_b);
@ -343,7 +308,7 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
{ {
llama_file gguf_file(path_lora, "rb"); llama_file gguf_file(path_lora, "rb");
std::vector<uint8_t> read_buf; std::vector<uint8_t> read_buf;
auto set_tensor = [&](ggml_tensor * orig, ggml_tensor * dev) { auto set_tensor = [&](struct ggml_tensor * orig, struct ggml_tensor * dev) {
size_t offs = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), gguf_find_tensor(ctx_gguf.get(), orig->name)); size_t offs = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), gguf_find_tensor(ctx_gguf.get(), orig->name));
size_t size = ggml_nbytes(orig); size_t size = ggml_nbytes(orig);
read_buf.resize(size); read_buf.resize(size);
@ -362,8 +327,8 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2); LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
} }
llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * path_lora) { struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model, const char * path_lora) {
llama_adapter_lora * adapter = new llama_adapter_lora(); struct llama_adapter_lora * adapter = new llama_adapter_lora();
try { try {
llama_adapter_lora_init_impl(*model, path_lora, *adapter); llama_adapter_lora_init_impl(*model, path_lora, *adapter);
@ -377,6 +342,6 @@ llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * p
return nullptr; return nullptr;
} }
void llama_adapter_lora_free(llama_adapter_lora * adapter) { void llama_adapter_lora_free(struct llama_adapter_lora * adapter) {
delete adapter; delete adapter;
} }

View File

@ -15,11 +15,11 @@
// //
struct llama_adapter_cvec { struct llama_adapter_cvec {
ggml_tensor * tensor_for(int il) const; struct ggml_tensor * tensor_for(int il) const;
ggml_tensor * apply_to(ggml_context * ctx, ggml_tensor * cur, int il) const; struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const;
bool apply( int32_t apply(
const llama_model & model, const llama_model & model,
const float * data, const float * data,
size_t len, size_t len,
@ -36,7 +36,7 @@ private:
std::vector<ggml_context_ptr> ctxs; std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs; std::vector<ggml_backend_buffer_ptr> bufs;
std::vector<ggml_tensor *> tensors; // per layer std::vector<struct ggml_tensor *> tensors; // per layer
}; };
// //
@ -44,8 +44,8 @@ private:
// //
struct llama_adapter_lora_weight { struct llama_adapter_lora_weight {
ggml_tensor * a = nullptr; struct ggml_tensor * a = nullptr;
ggml_tensor * b = nullptr; struct ggml_tensor * b = nullptr;
// get actual scale based on rank and alpha // get actual scale based on rank and alpha
float get_scale(float alpha, float adapter_scale) const { float get_scale(float alpha, float adapter_scale) const {
@ -55,12 +55,12 @@ struct llama_adapter_lora_weight {
} }
llama_adapter_lora_weight() = default; llama_adapter_lora_weight() = default;
llama_adapter_lora_weight(ggml_tensor * a, ggml_tensor * b) : a(a), b(b) {} llama_adapter_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {}
}; };
struct llama_adapter_lora { struct llama_adapter_lora {
// map tensor name to lora_a_b // map tensor name to lora_a_b
std::unordered_map<std::string, llama_adapter_lora_weight> ab_map; std::unordered_map<std::string, struct llama_adapter_lora_weight> ab_map;
std::vector<ggml_context_ptr> ctxs; std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs; std::vector<ggml_backend_buffer_ptr> bufs;
@ -70,7 +70,5 @@ struct llama_adapter_lora {
llama_adapter_lora() = default; llama_adapter_lora() = default;
~llama_adapter_lora() = default; ~llama_adapter_lora() = default;
llama_adapter_lora_weight * get_weight(ggml_tensor * w); llama_adapter_lora_weight * get_weight(struct ggml_tensor * w);
}; };
using llama_adapter_loras = std::unordered_map<llama_adapter_lora *, float>;

View File

@ -6,7 +6,6 @@
static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_LLAMA, "llama" }, { LLM_ARCH_LLAMA, "llama" },
{ LLM_ARCH_LLAMA4, "llama4" },
{ LLM_ARCH_DECI, "deci" }, { LLM_ARCH_DECI, "deci" },
{ LLM_ARCH_FALCON, "falcon" }, { LLM_ARCH_FALCON, "falcon" },
{ LLM_ARCH_GROK, "grok" }, { LLM_ARCH_GROK, "grok" },
@ -19,7 +18,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_REFACT, "refact" }, { LLM_ARCH_REFACT, "refact" },
{ LLM_ARCH_BERT, "bert" }, { LLM_ARCH_BERT, "bert" },
{ LLM_ARCH_NOMIC_BERT, "nomic-bert" }, { LLM_ARCH_NOMIC_BERT, "nomic-bert" },
{ LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
{ LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" }, { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
{ LLM_ARCH_BLOOM, "bloom" }, { LLM_ARCH_BLOOM, "bloom" },
{ LLM_ARCH_STABLELM, "stablelm" }, { LLM_ARCH_STABLELM, "stablelm" },
@ -27,8 +25,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_QWEN2, "qwen2" }, { LLM_ARCH_QWEN2, "qwen2" },
{ LLM_ARCH_QWEN2MOE, "qwen2moe" }, { LLM_ARCH_QWEN2MOE, "qwen2moe" },
{ LLM_ARCH_QWEN2VL, "qwen2vl" }, { LLM_ARCH_QWEN2VL, "qwen2vl" },
{ LLM_ARCH_QWEN3, "qwen3" },
{ LLM_ARCH_QWEN3MOE, "qwen3moe" },
{ LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PHI2, "phi2" },
{ LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PHI3, "phi3" },
{ LLM_ARCH_PHIMOE, "phimoe" }, { LLM_ARCH_PHIMOE, "phimoe" },
@ -40,7 +36,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_MINICPM3, "minicpm3" }, { LLM_ARCH_MINICPM3, "minicpm3" },
{ LLM_ARCH_GEMMA, "gemma" }, { LLM_ARCH_GEMMA, "gemma" },
{ LLM_ARCH_GEMMA2, "gemma2" }, { LLM_ARCH_GEMMA2, "gemma2" },
{ LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_XVERSE, "xverse" },
@ -55,7 +50,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_DEEPSEEK, "deepseek" }, { LLM_ARCH_DEEPSEEK, "deepseek" },
{ LLM_ARCH_DEEPSEEK2, "deepseek2" }, { LLM_ARCH_DEEPSEEK2, "deepseek2" },
{ LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_CHATGLM, "chatglm" },
{ LLM_ARCH_GLM4, "glm4" },
{ LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_BITNET, "bitnet" },
{ LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5, "t5" },
{ LLM_ARCH_T5ENCODER, "t5encoder" }, { LLM_ARCH_T5ENCODER, "t5encoder" },
@ -64,14 +58,10 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_EXAONE, "exaone" }, { LLM_ARCH_EXAONE, "exaone" },
{ LLM_ARCH_RWKV6, "rwkv6" }, { LLM_ARCH_RWKV6, "rwkv6" },
{ LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" }, { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
{ LLM_ARCH_RWKV7, "rwkv7" },
{ LLM_ARCH_ARWKV7, "arwkv7" },
{ LLM_ARCH_GRANITE, "granite" }, { LLM_ARCH_GRANITE, "granite" },
{ LLM_ARCH_GRANITE_MOE, "granitemoe" }, { LLM_ARCH_GRANITE_MOE, "granitemoe" },
{ LLM_ARCH_CHAMELEON, "chameleon" }, { LLM_ARCH_CHAMELEON, "chameleon" },
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
{ LLM_ARCH_PLM, "plm" },
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
{ LLM_ARCH_UNKNOWN, "(unknown)" }, { LLM_ARCH_UNKNOWN, "(unknown)" },
}; };
@ -80,7 +70,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" },
{ LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" },
{ LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" },
{ LLM_KV_GENERAL_FILE_TYPE, "general.file_type" },
{ LLM_KV_GENERAL_NAME, "general.name" }, { LLM_KV_GENERAL_NAME, "general.name" },
{ LLM_KV_GENERAL_AUTHOR, "general.author" }, { LLM_KV_GENERAL_AUTHOR, "general.author" },
{ LLM_KV_GENERAL_VERSION, "general.version" }, { LLM_KV_GENERAL_VERSION, "general.version" },
@ -107,7 +96,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
{ LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" }, { LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
{ LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" }, { LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
{ LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" },
{ LLM_KV_POOLING_TYPE, "%s.pooling_type" }, { LLM_KV_POOLING_TYPE, "%s.pooling_type" },
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
@ -120,7 +108,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" }, { LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" },
{ LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" }, { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
{ LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" }, { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" },
{ LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" },
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@ -135,15 +122,9 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" }, { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
{ LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" }, { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
{ LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
{ LLM_KV_ATTENTION_DECAY_LORA_RANK, "%s.attention.decay_lora_rank" },
{ LLM_KV_ATTENTION_ICLR_LORA_RANK, "%s.attention.iclr_lora_rank" },
{ LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, "%s.attention.value_residual_mix_lora_rank" },
{ LLM_KV_ATTENTION_GATE_LORA_RANK, "%s.attention.gate_lora_rank" },
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
@ -242,35 +223,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
}, },
}, },
{
LLM_ARCH_LLAMA4,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{ {
LLM_ARCH_DECI, LLM_ARCH_DECI,
{ {
@ -474,24 +426,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
}, },
}, },
{
LLM_ARCH_NOMIC_BERT_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_TOKEN_TYPES, "token_types" },
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{ {
LLM_ARCH_JINA_BERT_V2, LLM_ARCH_JINA_BERT_V2,
{ {
@ -620,45 +554,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
}, },
}, },
{
LLM_ARCH_QWEN3,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_QWEN3MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{ {
LLM_ARCH_PHI2, LLM_ARCH_PHI2,
{ {
@ -871,27 +766,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
}, },
}, },
{
LLM_ARCH_GEMMA3,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{ {
LLM_ARCH_STARCODER2, LLM_ARCH_STARCODER2,
{ {
@ -1125,8 +999,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
{ LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
{ LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
@ -1143,22 +1015,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
}, },
}, },
{
LLM_ARCH_PLM,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
{ LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" },
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{ {
LLM_ARCH_CHATGLM, LLM_ARCH_CHATGLM,
{ {
@ -1177,25 +1033,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
}, },
}, },
{
LLM_ARCH_GLM4,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{ {
LLM_ARCH_BITNET, LLM_ARCH_BITNET,
{ {
@ -1380,74 +1217,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
}, },
}, },
{
LLM_ARCH_RWKV7,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
{ LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" },
{ LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
{ LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
{ LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" },
{ LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" },
{ LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" },
{ LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" },
{ LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" },
{ LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" },
{ LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" },
{ LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" },
{ LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" },
{ LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" },
{ LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" },
{ LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
{ LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
{ LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
{ LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
{ LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
{ LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
{ LLM_TENSOR_CHANNEL_MIX_LERP_K, "blk.%d.channel_mix_lerp_k" },
{ LLM_TENSOR_CHANNEL_MIX_KEY, "blk.%d.channel_mix_key" },
{ LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" },
},
},
{
LLM_ARCH_ARWKV7,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" },
{ LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
{ LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
{ LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" },
{ LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" },
{ LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" },
{ LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" },
{ LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" },
{ LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" },
{ LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" },
{ LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" },
{ LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" },
{ LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" },
{ LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" },
{ LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
{ LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
{ LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
{ LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
{ LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
{ LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{ {
LLM_ARCH_GRANITE, LLM_ARCH_GRANITE,
{ {
@ -1527,29 +1296,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" }, { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" },
}, },
}, },
{
LLM_ARCH_BAILINGMOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{ {
LLM_ARCH_UNKNOWN, LLM_ARCH_UNKNOWN,
{ {
@ -1587,8 +1333,23 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
@ -1615,12 +1376,6 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_A1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_A2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_V1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_V2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_G1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_G2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_DECAY_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_DECAY_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_DECAY_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_DECAY_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_KEY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_KEY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
@ -1639,9 +1394,6 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CHANNEL_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CHANNEL_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_K_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_K_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_R_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_LERP_W, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_LERP_W, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
@ -1649,9 +1401,6 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_LERP_FUSED, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_LERP_FUSED, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_W0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_A0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_V0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}}, {LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
{LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},

View File

@ -10,7 +10,6 @@
enum llm_arch { enum llm_arch {
LLM_ARCH_LLAMA, LLM_ARCH_LLAMA,
LLM_ARCH_LLAMA4,
LLM_ARCH_DECI, LLM_ARCH_DECI,
LLM_ARCH_FALCON, LLM_ARCH_FALCON,
LLM_ARCH_BAICHUAN, LLM_ARCH_BAICHUAN,
@ -23,7 +22,6 @@ enum llm_arch {
LLM_ARCH_REFACT, LLM_ARCH_REFACT,
LLM_ARCH_BERT, LLM_ARCH_BERT,
LLM_ARCH_NOMIC_BERT, LLM_ARCH_NOMIC_BERT,
LLM_ARCH_NOMIC_BERT_MOE,
LLM_ARCH_JINA_BERT_V2, LLM_ARCH_JINA_BERT_V2,
LLM_ARCH_BLOOM, LLM_ARCH_BLOOM,
LLM_ARCH_STABLELM, LLM_ARCH_STABLELM,
@ -31,8 +29,6 @@ enum llm_arch {
LLM_ARCH_QWEN2, LLM_ARCH_QWEN2,
LLM_ARCH_QWEN2MOE, LLM_ARCH_QWEN2MOE,
LLM_ARCH_QWEN2VL, LLM_ARCH_QWEN2VL,
LLM_ARCH_QWEN3,
LLM_ARCH_QWEN3MOE,
LLM_ARCH_PHI2, LLM_ARCH_PHI2,
LLM_ARCH_PHI3, LLM_ARCH_PHI3,
LLM_ARCH_PHIMOE, LLM_ARCH_PHIMOE,
@ -44,7 +40,6 @@ enum llm_arch {
LLM_ARCH_MINICPM3, LLM_ARCH_MINICPM3,
LLM_ARCH_GEMMA, LLM_ARCH_GEMMA,
LLM_ARCH_GEMMA2, LLM_ARCH_GEMMA2,
LLM_ARCH_GEMMA3,
LLM_ARCH_STARCODER2, LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA, LLM_ARCH_MAMBA,
LLM_ARCH_XVERSE, LLM_ARCH_XVERSE,
@ -59,7 +54,6 @@ enum llm_arch {
LLM_ARCH_DEEPSEEK, LLM_ARCH_DEEPSEEK,
LLM_ARCH_DEEPSEEK2, LLM_ARCH_DEEPSEEK2,
LLM_ARCH_CHATGLM, LLM_ARCH_CHATGLM,
LLM_ARCH_GLM4,
LLM_ARCH_BITNET, LLM_ARCH_BITNET,
LLM_ARCH_T5, LLM_ARCH_T5,
LLM_ARCH_T5ENCODER, LLM_ARCH_T5ENCODER,
@ -68,14 +62,10 @@ enum llm_arch {
LLM_ARCH_EXAONE, LLM_ARCH_EXAONE,
LLM_ARCH_RWKV6, LLM_ARCH_RWKV6,
LLM_ARCH_RWKV6QWEN2, LLM_ARCH_RWKV6QWEN2,
LLM_ARCH_RWKV7,
LLM_ARCH_ARWKV7,
LLM_ARCH_GRANITE, LLM_ARCH_GRANITE,
LLM_ARCH_GRANITE_MOE, LLM_ARCH_GRANITE_MOE,
LLM_ARCH_CHAMELEON, LLM_ARCH_CHAMELEON,
LLM_ARCH_WAVTOKENIZER_DEC, LLM_ARCH_WAVTOKENIZER_DEC,
LLM_ARCH_PLM,
LLM_ARCH_BAILINGMOE,
LLM_ARCH_UNKNOWN, LLM_ARCH_UNKNOWN,
}; };
@ -84,7 +74,6 @@ enum llm_kv {
LLM_KV_GENERAL_ARCHITECTURE, LLM_KV_GENERAL_ARCHITECTURE,
LLM_KV_GENERAL_QUANTIZATION_VERSION, LLM_KV_GENERAL_QUANTIZATION_VERSION,
LLM_KV_GENERAL_ALIGNMENT, LLM_KV_GENERAL_ALIGNMENT,
LLM_KV_GENERAL_FILE_TYPE,
LLM_KV_GENERAL_NAME, LLM_KV_GENERAL_NAME,
LLM_KV_GENERAL_AUTHOR, LLM_KV_GENERAL_AUTHOR,
LLM_KV_GENERAL_VERSION, LLM_KV_GENERAL_VERSION,
@ -111,7 +100,6 @@ enum llm_kv {
LLM_KV_EXPERT_WEIGHTS_SCALE, LLM_KV_EXPERT_WEIGHTS_SCALE,
LLM_KV_EXPERT_WEIGHTS_NORM, LLM_KV_EXPERT_WEIGHTS_NORM,
LLM_KV_EXPERT_GATING_FUNC, LLM_KV_EXPERT_GATING_FUNC,
LLM_KV_MOE_EVERY_N_LAYERS,
LLM_KV_POOLING_TYPE, LLM_KV_POOLING_TYPE,
LLM_KV_LOGIT_SCALE, LLM_KV_LOGIT_SCALE,
LLM_KV_DECODER_START_TOKEN_ID, LLM_KV_DECODER_START_TOKEN_ID,
@ -124,7 +112,6 @@ enum llm_kv {
LLM_KV_RESIDUAL_SCALE, LLM_KV_RESIDUAL_SCALE,
LLM_KV_EMBEDDING_SCALE, LLM_KV_EMBEDDING_SCALE,
LLM_KV_TOKEN_SHIFT_COUNT, LLM_KV_TOKEN_SHIFT_COUNT,
LLM_KV_INTERLEAVE_MOE_LAYER_STEP,
LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV, LLM_KV_ATTENTION_HEAD_COUNT_KV,
@ -139,15 +126,9 @@ enum llm_kv {
LLM_KV_ATTENTION_CAUSAL, LLM_KV_ATTENTION_CAUSAL,
LLM_KV_ATTENTION_Q_LORA_RANK, LLM_KV_ATTENTION_Q_LORA_RANK,
LLM_KV_ATTENTION_KV_LORA_RANK, LLM_KV_ATTENTION_KV_LORA_RANK,
LLM_KV_ATTENTION_DECAY_LORA_RANK,
LLM_KV_ATTENTION_ICLR_LORA_RANK,
LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK,
LLM_KV_ATTENTION_GATE_LORA_RANK,
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
LLM_KV_ATTENTION_SLIDING_WINDOW, LLM_KV_ATTENTION_SLIDING_WINDOW,
LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_SCALE,
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_DIMENSION_SECTIONS, LLM_KV_ROPE_DIMENSION_SECTIONS,
@ -261,8 +242,6 @@ enum llm_tensor {
LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_Q_NORM,
LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_ATTN_K_NORM,
LLM_TENSOR_LAYER_OUT_NORM, LLM_TENSOR_LAYER_OUT_NORM,
LLM_TENSOR_POST_ATTN_NORM,
LLM_TENSOR_POST_MLP_NORM,
LLM_TENSOR_SSM_IN, LLM_TENSOR_SSM_IN,
LLM_TENSOR_SSM_CONV1D, LLM_TENSOR_SSM_CONV1D,
LLM_TENSOR_SSM_X, LLM_TENSOR_SSM_X,
@ -270,20 +249,8 @@ enum llm_tensor {
LLM_TENSOR_SSM_A, LLM_TENSOR_SSM_A,
LLM_TENSOR_SSM_D, LLM_TENSOR_SSM_D,
LLM_TENSOR_SSM_OUT, LLM_TENSOR_SSM_OUT,
LLM_TENSOR_TIME_MIX_W0,
LLM_TENSOR_TIME_MIX_W1, LLM_TENSOR_TIME_MIX_W1,
LLM_TENSOR_TIME_MIX_W2, LLM_TENSOR_TIME_MIX_W2,
LLM_TENSOR_TIME_MIX_A0,
LLM_TENSOR_TIME_MIX_A1,
LLM_TENSOR_TIME_MIX_A2,
LLM_TENSOR_TIME_MIX_V0,
LLM_TENSOR_TIME_MIX_V1,
LLM_TENSOR_TIME_MIX_V2,
LLM_TENSOR_TIME_MIX_G1,
LLM_TENSOR_TIME_MIX_G2,
LLM_TENSOR_TIME_MIX_K_K,
LLM_TENSOR_TIME_MIX_K_A,
LLM_TENSOR_TIME_MIX_R_K,
LLM_TENSOR_TIME_MIX_LERP_X, LLM_TENSOR_TIME_MIX_LERP_X,
LLM_TENSOR_TIME_MIX_LERP_W, LLM_TENSOR_TIME_MIX_LERP_W,
LLM_TENSOR_TIME_MIX_LERP_K, LLM_TENSOR_TIME_MIX_LERP_K,
@ -310,8 +277,6 @@ enum llm_tensor {
LLM_TENSOR_ATTN_Q_B, LLM_TENSOR_ATTN_Q_B,
LLM_TENSOR_ATTN_KV_A_MQA, LLM_TENSOR_ATTN_KV_A_MQA,
LLM_TENSOR_ATTN_KV_B, LLM_TENSOR_ATTN_KV_B,
LLM_TENSOR_ATTN_K_B,
LLM_TENSOR_ATTN_V_B,
LLM_TENSOR_ATTN_Q_A_NORM, LLM_TENSOR_ATTN_Q_A_NORM,
LLM_TENSOR_ATTN_KV_A_NORM, LLM_TENSOR_ATTN_KV_A_NORM,
LLM_TENSOR_ATTN_SUB_NORM, LLM_TENSOR_ATTN_SUB_NORM,

View File

@ -42,9 +42,9 @@ struct llama_sbatch {
bool logits_all; // TODO: remove once lctx.logits_all is removed too bool logits_all; // TODO: remove once lctx.logits_all is removed too
// sorted indices into the batch // sorted indices into the batch
std::vector<int64_t> ids; std::vector<size_t> ids;
// batch indices of the output // batch indices of the output
std::vector<int64_t> out_ids; std::vector<size_t> out_ids;
std::vector<llama_sbatch_seq> seq; std::vector<llama_sbatch_seq> seq;
const llama_batch * batch = nullptr; const llama_batch * batch = nullptr;

View File

@ -4,7 +4,6 @@
#include <map> #include <map>
#include <sstream> #include <sstream>
#include <algorithm>
#if __cplusplus >= 202000L #if __cplusplus >= 202000L
#define LU8(x) (const char*)(u8##x) #define LU8(x) (const char*)(u8##x)
@ -50,8 +49,8 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "deepseek3", LLM_CHAT_TEMPLATE_DEEPSEEK_3 }, { "deepseek3", LLM_CHAT_TEMPLATE_DEEPSEEK_3 },
{ "command-r", LLM_CHAT_TEMPLATE_COMMAND_R }, { "command-r", LLM_CHAT_TEMPLATE_COMMAND_R },
{ "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 }, { "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 },
{ "chatglm3", LLM_CHAT_TEMPLATE_CHATGLM_3 }, { "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 },
{ "chatglm4", LLM_CHAT_TEMPLATE_CHATGLM_4 }, { "chatglm4", LLM_CHAT_TEMPLATE_CHATGML_4 },
{ "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE }, { "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE },
{ "minicpm", LLM_CHAT_TEMPLATE_MINICPM }, { "minicpm", LLM_CHAT_TEMPLATE_MINICPM },
{ "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 }, { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
@ -59,10 +58,6 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "granite", LLM_CHAT_TEMPLATE_GRANITE }, { "granite", LLM_CHAT_TEMPLATE_GRANITE },
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
{ "megrez", LLM_CHAT_TEMPLATE_MEGREZ }, { "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
{ "yandex", LLM_CHAT_TEMPLATE_YANDEX },
{ "bailing", LLM_CHAT_TEMPLATE_BAILING },
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
{ "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
}; };
llm_chat_template llm_chat_template_from_str(const std::string & name) { llm_chat_template llm_chat_template_from_str(const std::string & name) {
@ -82,8 +77,6 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
if (tmpl_contains("<|im_start|>")) { if (tmpl_contains("<|im_start|>")) {
return tmpl_contains("<|im_sep|>") return tmpl_contains("<|im_sep|>")
? LLM_CHAT_TEMPLATE_PHI_4 ? LLM_CHAT_TEMPLATE_PHI_4
: tmpl_contains("<end_of_utterance>")
? LLM_CHAT_TEMPLATE_SMOLVLM // SmolVLM uses <|im_start|> as BOS, but it is NOT chatml
: LLM_CHAT_TEMPLATE_CHATML; : LLM_CHAT_TEMPLATE_CHATML;
} else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) { } else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) {
if (tmpl_contains("[SYSTEM_PROMPT]")) { if (tmpl_contains("[SYSTEM_PROMPT]")) {
@ -122,12 +115,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
} }
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) { } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) {
return LLM_CHAT_TEMPLATE_PHI_3; return LLM_CHAT_TEMPLATE_PHI_3;
} else if (tmpl_contains("[gMASK]<sop>")) {
return LLM_CHAT_TEMPLATE_CHATGLM_4;
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) { } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
return tmpl_contains("</s>") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE; return tmpl_contains("</s>") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE;
} else if (tmpl_contains("<|{{ item['role'] }}|>") && tmpl_contains("<|begin_of_image|>")) {
return LLM_CHAT_TEMPLATE_GLMEDGE;
} else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) { } else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) {
return LLM_CHAT_TEMPLATE_ZEPHYR; return LLM_CHAT_TEMPLATE_ZEPHYR;
} else if (tmpl_contains("bos_token + message['role']")) { } else if (tmpl_contains("bos_token + message['role']")) {
@ -156,7 +145,9 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_LLAMA_3; return LLM_CHAT_TEMPLATE_LLAMA_3;
} else if (tmpl_contains("[gMASK]sop")) { } else if (tmpl_contains("[gMASK]sop")) {
// chatglm3-6b // chatglm3-6b
return LLM_CHAT_TEMPLATE_CHATGLM_3; return LLM_CHAT_TEMPLATE_CHATGML_3;
} else if (tmpl_contains("[gMASK]<sop>")) {
return LLM_CHAT_TEMPLATE_CHATGML_4;
} else if (tmpl_contains(LU8("<用户>"))) { } else if (tmpl_contains(LU8("<用户>"))) {
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
return LLM_CHAT_TEMPLATE_MINICPM; return LLM_CHAT_TEMPLATE_MINICPM;
@ -176,12 +167,6 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_GIGACHAT; return LLM_CHAT_TEMPLATE_GIGACHAT;
} else if (tmpl_contains("<|role_start|>")) { } else if (tmpl_contains("<|role_start|>")) {
return LLM_CHAT_TEMPLATE_MEGREZ; return LLM_CHAT_TEMPLATE_MEGREZ;
} else if (tmpl_contains(" Ассистент:")) {
return LLM_CHAT_TEMPLATE_YANDEX;
} else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("'HUMAN'")) {
return LLM_CHAT_TEMPLATE_BAILING;
} else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
return LLM_CHAT_TEMPLATE_LLAMA4;
} }
return LLM_CHAT_TEMPLATE_UNKNOWN; return LLM_CHAT_TEMPLATE_UNKNOWN;
} }
@ -437,7 +422,7 @@ int32_t llm_chat_apply_template(
if (add_ass) { if (add_ass) {
ss << "<|start_header_id|>assistant<|end_header_id|>\n\n"; ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
} }
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_3) { } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_3) {
// chatglm3-6b // chatglm3-6b
ss << "[gMASK]" << "sop"; ss << "[gMASK]" << "sop";
for (auto message : chat) { for (auto message : chat) {
@ -447,7 +432,7 @@ int32_t llm_chat_apply_template(
if (add_ass) { if (add_ass) {
ss << "<|assistant|>"; ss << "<|assistant|>";
} }
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_4 || tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) { } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_4) {
ss << "[gMASK]" << "<sop>"; ss << "[gMASK]" << "<sop>";
for (auto message : chat) { for (auto message : chat) {
std::string role(message->role); std::string role(message->role);
@ -456,6 +441,14 @@ int32_t llm_chat_apply_template(
if (add_ass) { if (add_ass) {
ss << "<|assistant|>"; ss << "<|assistant|>";
} }
} else if (tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) {
for (auto message : chat) {
std::string role(message->role);
ss << "<|" << role << "|>" << "\n" << message->content;
}
if (add_ass) {
ss << "<|assistant|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_MINICPM) { } else if (tmpl == LLM_CHAT_TEMPLATE_MINICPM) {
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
for (auto message : chat) { for (auto message : chat) {
@ -573,66 +566,6 @@ int32_t llm_chat_apply_template(
if (add_ass) { if (add_ass) {
ss << "<|role_start|>assistant<|role_end|>"; ss << "<|role_start|>assistant<|role_end|>";
} }
} else if (tmpl == LLM_CHAT_TEMPLATE_YANDEX) {
// Yandex template ("\n\n" is defined as EOT token)
ss << "<s>";
for (size_t i = 0; i < chat.size(); i++) {
std::string role(chat[i]->role);
if (role == "user") {
ss << " Пользователь: " << chat[i]->content << "\n\n";
} else if (role == "assistant") {
ss << " Ассистент: " << chat[i]->content << "\n\n";
}
}
// Add generation prompt if needed
if (add_ass) {
ss << " Ассистент:[SEP]";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_BAILING) {
// Bailing (Ling) template
for (auto message : chat) {
std::string role(message->role);
if (role == "user") {
role = "HUMAN";
} else {
std::transform(role.begin(), role.end(), role.begin(), ::toupper);
}
ss << "<role>" << role << "</role>" << message->content;
}
if (add_ass) {
ss << "<role>ASSISTANT</role>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA4) {
// Llama 4
for (auto message : chat) {
std::string role(message->role);
ss << "<|header_start|>" << role << "<|header_end|>\n\n" << trim(message->content) << "<|eot|>";
}
if (add_ass) {
ss << "<|header_start|>assistant<|header_end|>\n\n";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_SMOLVLM) {
// SmolVLM
ss << "<|im_start|>"; // uses <|im_start|> as BOS, but the actual content is NOT chatml
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << message->content << "\n\n";
} else if (role == "user") {
ss << "User: " << message->content << "<end_of_utterance>\n";
} else {
ss << "Assistant: " << message->content << "<end_of_utterance>\n";
}
}
if (add_ass) {
ss << "Assistant:";
}
} else { } else {
// template not supported // template not supported
return -1; return -1;
@ -651,3 +584,4 @@ int32_t llama_chat_builtin_templates(const char ** output, size_t len) {
} }
return (int32_t) LLM_CHAT_TEMPLATES.size(); return (int32_t) LLM_CHAT_TEMPLATES.size();
} }

View File

@ -29,8 +29,8 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_DEEPSEEK_3, LLM_CHAT_TEMPLATE_DEEPSEEK_3,
LLM_CHAT_TEMPLATE_COMMAND_R, LLM_CHAT_TEMPLATE_COMMAND_R,
LLM_CHAT_TEMPLATE_LLAMA_3, LLM_CHAT_TEMPLATE_LLAMA_3,
LLM_CHAT_TEMPLATE_CHATGLM_3, LLM_CHAT_TEMPLATE_CHATGML_3,
LLM_CHAT_TEMPLATE_CHATGLM_4, LLM_CHAT_TEMPLATE_CHATGML_4,
LLM_CHAT_TEMPLATE_GLMEDGE, LLM_CHAT_TEMPLATE_GLMEDGE,
LLM_CHAT_TEMPLATE_MINICPM, LLM_CHAT_TEMPLATE_MINICPM,
LLM_CHAT_TEMPLATE_EXAONE_3, LLM_CHAT_TEMPLATE_EXAONE_3,
@ -38,10 +38,6 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_GRANITE, LLM_CHAT_TEMPLATE_GRANITE,
LLM_CHAT_TEMPLATE_GIGACHAT, LLM_CHAT_TEMPLATE_GIGACHAT,
LLM_CHAT_TEMPLATE_MEGREZ, LLM_CHAT_TEMPLATE_MEGREZ,
LLM_CHAT_TEMPLATE_YANDEX,
LLM_CHAT_TEMPLATE_BAILING,
LLM_CHAT_TEMPLATE_LLAMA4,
LLM_CHAT_TEMPLATE_SMOLVLM,
LLM_CHAT_TEMPLATE_UNKNOWN, LLM_CHAT_TEMPLATE_UNKNOWN,
}; };

File diff suppressed because it is too large Load Diff

View File

@ -3,212 +3,66 @@
#include "llama.h" #include "llama.h"
#include "llama-batch.h" #include "llama-batch.h"
#include "llama-cparams.h" #include "llama-cparams.h"
#include "llama-graph.h" #include "llama-model.h"
#include "llama-kv-cache.h"
#include "llama-adapter.h" #include "llama-adapter.h"
#include "ggml-cpp.h" #include "ggml-cpp.h"
#include <map> #include <map>
#include <unordered_map>
#include <vector> #include <vector>
#include <set>
struct llama_model;
struct llama_kv_cache;
class llama_io_read_i;
class llama_io_write_i;
struct llama_context { struct llama_context {
// init scheduler and compute buffers, reserve worst-case graphs llama_context(const llama_model & model)
llama_context( : model(model)
const llama_model & model, , t_start_us(model.t_start_us)
llama_context_params params); , t_load_us(model.t_load_us) {}
~llama_context(); const struct llama_model & model;
void synchronize(); struct llama_cparams cparams;
struct llama_sbatch sbatch; // TODO: revisit if needed
struct llama_kv_cache kv_self;
struct llama_adapter_cvec cvec;
const llama_model & get_model() const; std::unordered_map<struct llama_adapter_lora *, float> lora;
uint32_t n_ctx() const; std::vector<ggml_backend_ptr> backends;
uint32_t n_ctx_per_seq() const; std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
uint32_t n_batch() const;
uint32_t n_ubatch() const;
uint32_t n_seq_max() const;
uint32_t n_threads() const; ggml_backend_t backend_cpu = nullptr;
uint32_t n_threads_batch() const;
llama_kv_cache * get_kv_self(); ggml_threadpool_t threadpool = nullptr;
const llama_kv_cache * get_kv_self() const; ggml_threadpool_t threadpool_batch = nullptr;
void kv_self_update(); bool has_evaluated_once = false;
enum llama_pooling_type pooling_type() const; mutable int64_t t_start_us;
mutable int64_t t_load_us;
mutable int64_t t_p_eval_us = 0;
mutable int64_t t_eval_us = 0;
float * get_logits(); mutable int64_t t_compute_start_us = 0;
float * get_logits_ith(int32_t i); mutable int64_t n_queued_tokens = 0;
float * get_embeddings(); mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
float * get_embeddings_ith(int32_t i); mutable int32_t n_eval = 0; // number of eval calls
float * get_embeddings_seq(llama_seq_id seq_id);
void attach_threadpool( // host buffer for the model output (logits and embeddings)
ggml_threadpool_t threadpool, ggml_backend_buffer_ptr buf_output;
ggml_threadpool_t threadpool_batch);
void detach_threadpool();
void set_n_threads(int32_t n_threads, int32_t n_threads_batch);
void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);
void set_embeddings (bool value);
void set_causal_attn(bool value);
void set_warmup(bool value);
void set_adapter_lora(
llama_adapter_lora * adapter,
float scale);
bool rm_adapter_lora(
llama_adapter_lora * adapter);
void clear_adapter_lora();
bool apply_adapter_cvec(
const float * data,
size_t len,
int32_t n_embd,
int32_t il_start,
int32_t il_end);
int encode(llama_batch & inp_batch);
int decode(llama_batch & inp_batch);
//
// state save/load
//
size_t state_get_size();
size_t state_get_data( uint8_t * dst, size_t size);
size_t state_set_data(const uint8_t * src, size_t size);
size_t state_seq_get_size(llama_seq_id seq_id);
size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size);
size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);
bool state_load_file(
const char * filepath,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out);
bool state_save_file(
const char * filepath,
const llama_token * tokens,
size_t n_token_count);
size_t state_seq_load_file(
llama_seq_id seq_id,
const char * filepath,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out);
size_t state_seq_save_file(
llama_seq_id seq_id,
const char * filepath,
const llama_token * tokens,
size_t n_token_count);
//
// perf
//
llama_perf_context_data perf_get_data() const;
void perf_reset();
private:
//
// output
//
// Make sure enough space is available for outputs.
// Returns max number of outputs for which space was reserved.
int32_t output_reserve(int32_t n_outputs);
// make the outputs have the same order they had in the user-provided batch
// TODO: maybe remove this
void output_reorder();
//
// graph
//
int32_t graph_max_nodes() const;
// zero-out inputs and create the ctx_compute for the compute graph
ggml_cgraph * graph_init();
llm_graph_result_ptr graph_build(
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype);
// returns the result of ggml_backend_sched_graph_compute_async execution
ggml_status graph_compute(
ggml_cgraph * gf,
bool batched);
llm_graph_cb graph_get_cb() const;
// used by kv_self_update()
ggml_tensor * build_rope_shift(
ggml_context * ctx0,
ggml_tensor * cur,
ggml_tensor * shift,
ggml_tensor * factors,
float freq_base,
float freq_scale) const;
llm_graph_result_ptr build_kv_self_shift(
ggml_context * ctx0,
ggml_cgraph * gf) const;
llm_graph_result_ptr build_kv_self_defrag(
ggml_context * ctx0,
ggml_cgraph * gf) const;
// TODO: read/write lora adapters and cvec
size_t state_write_data(llama_io_write_i & io);
size_t state_read_data (llama_io_read_i & io);
size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id);
size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id);
//
// members
//
const llama_model & model;
llama_cparams cparams;
llama_adapter_cvec cvec;
llama_adapter_loras loras;
llama_sbatch sbatch;
llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
std::unique_ptr<llama_kv_cache_unified> kv_self;
// TODO: remove
bool logits_all = false;
// decode output (2-dimensional array: [n_outputs][n_vocab]) // decode output (2-dimensional array: [n_outputs][n_vocab])
size_t logits_size = 0; // capacity (of floats) for logits size_t logits_size = 0; // capacity (of floats) for logits
float * logits = nullptr; float * logits = nullptr;
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
size_t output_size = 0; // capacity (of tokens positions) for the output buffers
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
bool logits_all = false;
// embeddings output (2-dimensional array: [n_outputs][n_embd]) // embeddings output (2-dimensional array: [n_outputs][n_embd])
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
size_t embd_size = 0; // capacity (of floats) for embeddings size_t embd_size = 0; // capacity (of floats) for embeddings
@ -218,47 +72,57 @@ private:
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
std::map<llama_seq_id, std::vector<float>> embd_seq; std::map<llama_seq_id, std::vector<float>> embd_seq;
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch // whether we are computing encoder output or decoder output
int32_t n_outputs_max = 0; // capacity (of tokens positions) for the output buffers bool is_encoding = false;
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers // TODO: find a better way to accommodate mutli-dimension position encoding methods
// number of position id each token get, 1 for each token in most cases.
// when using m-rope, it will be 3 position ids per token to representing 3 dimension coordinate.
int n_pos_per_token = 1;
// output of the encoder part of the encoder-decoder models
std::vector<float> embd_enc;
std::vector<std::set<llama_seq_id>> seq_ids_enc;
// memory buffers used to evaluate the model
std::vector<uint8_t> buf_compute_meta;
ggml_backend_sched_ptr sched; ggml_backend_sched_ptr sched;
ggml_backend_t backend_cpu = nullptr;
std::vector<ggml_backend_ptr> backends;
ggml_context_ptr ctx_compute;
ggml_threadpool_t threadpool = nullptr;
ggml_threadpool_t threadpool_batch = nullptr;
ggml_abort_callback abort_callback = nullptr; ggml_abort_callback abort_callback = nullptr;
void * abort_callback_data = nullptr; void * abort_callback_data = nullptr;
std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns; // input tensors
struct ggml_tensor * inp_tokens; // I32 [n_batch]
// buffer types used for the compute buffer of each backend struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
std::vector<ggml_backend_t> backend_ptrs; struct ggml_tensor * inp_pos; // I32 [n_batch]
std::vector<ggml_backend_buffer_type_t> backend_buft; struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
// memory buffers used to evaluate the model struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch]
std::vector<uint8_t> buf_compute_meta; struct ggml_tensor * inp_K_shift; // I32 [kv_size]
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
// host buffer for the model output (logits and embeddings) struct ggml_tensor * inp_cls; // I32 [n_batch]
ggml_backend_buffer_ptr buf_output; struct ggml_tensor * inp_s_copy; // I32 [kv_size]
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
bool has_evaluated_once = false; struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
// perf struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
mutable int64_t t_start_us = 0; struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
mutable int64_t t_load_us = 0;
mutable int64_t t_p_eval_us = 0;
mutable int64_t t_eval_us = 0;
mutable int64_t t_compute_start_us = 0;
mutable int64_t n_queued_tokens = 0;
mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
mutable int32_t n_eval = 0; // number of eval calls
}; };
// TODO: make these methods of llama_context
void llama_set_k_shift(struct llama_context & lctx);
void llama_set_s_copy(struct llama_context & lctx);
void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch);
// Make sure enough space is available for outputs.
// Returns max number of outputs for which space was reserved.
size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs);
// make the outputs have the same order they had in the user-provided batch
void llama_output_reorder(struct llama_context & ctx);
// For internal test use
// TODO: remove
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(struct llama_context * ctx);

View File

@ -29,7 +29,6 @@ struct llama_cparams {
bool offload_kqv; bool offload_kqv;
bool flash_attn; bool flash_attn;
bool no_perf; bool no_perf;
bool warmup;
enum llama_pooling_type pooling_type; enum llama_pooling_type pooling_type;

View File

@ -508,7 +508,7 @@ const char * llama_grammar_parser::parse_sequence(
} }
} }
return pos; return pos;
} }
const char * llama_grammar_parser::parse_rule(const char * src) { const char * llama_grammar_parser::parse_rule(const char * src) {
const char * name_end = parse_name(src); const char * name_end = parse_name(src);
@ -532,7 +532,7 @@ const char * llama_grammar_parser::parse_rule(const char * src) {
throw std::runtime_error(std::string("expecting newline or end at ") + pos); throw std::runtime_error(std::string("expecting newline or end at ") + pos);
} }
return parse_space(pos, true); return parse_space(pos, true);
} }
bool llama_grammar_parser::parse(const char * src) { bool llama_grammar_parser::parse(const char * src) {
try { try {
@ -969,7 +969,7 @@ struct llama_grammar * llama_grammar_init_impl(
/* .awaiting_trigger = */ false, /* .awaiting_trigger = */ false,
/* .trigger_buffer = */ "", /* .trigger_buffer = */ "",
/* .trigger_tokens = */ {}, /* .trigger_tokens = */ {},
/* .trigger_patterns = */ {}, /* .trigger_words = */ {},
}; };
} }
@ -978,15 +978,19 @@ struct llama_grammar * llama_grammar_init_impl(
const char * grammar_str, const char * grammar_str,
const char * grammar_root, const char * grammar_root,
bool lazy, bool lazy,
const char ** trigger_patterns, const char ** trigger_words,
size_t num_trigger_patterns, size_t num_trigger_words,
const llama_token * trigger_tokens, const llama_token * trigger_tokens,
size_t num_trigger_tokens) { size_t num_trigger_tokens) {
llama_grammar_parser parser; llama_grammar_parser parser;
// if there is a grammar, parse it // if there is a grammar, parse it
// rules will be empty (default) if there are parse errors if (!parser.parse(grammar_str)) {
if (!parser.parse(grammar_str) || parser.rules.empty()) { return nullptr;
}
// will be empty (default) if there are parse errors
if (parser.rules.empty()) {
fprintf(stderr, "%s: failed to parse grammar\n", __func__); fprintf(stderr, "%s: failed to parse grammar\n", __func__);
return nullptr; return nullptr;
} }
@ -1050,16 +1054,14 @@ struct llama_grammar * llama_grammar_init_impl(
} while (true); } while (true);
std::vector<llama_token> vec_trigger_tokens; std::vector<llama_token> vec_trigger_tokens;
std::vector<llama_grammar_trigger_pattern> vec_trigger_patterns; std::vector<std::string> vec_trigger_words;
for (size_t i = 0; i < num_trigger_tokens; i++) { for (size_t i = 0; i < num_trigger_tokens; i++) {
GGML_ASSERT(trigger_tokens != nullptr); GGML_ASSERT(trigger_tokens != nullptr);
vec_trigger_tokens.push_back(trigger_tokens[i]); vec_trigger_tokens.push_back(trigger_tokens[i]);
} }
for (size_t i = 0; i < num_trigger_patterns; i++) { for (size_t i = 0; i < num_trigger_words; i++) {
GGML_ASSERT(trigger_patterns != nullptr); GGML_ASSERT(trigger_words != nullptr);
auto & trigger = vec_trigger_patterns.emplace_back(); vec_trigger_words.push_back(trigger_words[i]);
trigger.pattern = trigger_patterns[i];
trigger.regex = std::regex(trigger.pattern);
} }
// Important: vec_rules has to be moved here, not copied, because stacks contains // Important: vec_rules has to be moved here, not copied, because stacks contains
@ -1074,7 +1076,7 @@ struct llama_grammar * llama_grammar_init_impl(
/* .awaiting_trigger = */ lazy, /* .awaiting_trigger = */ lazy,
/* .trigger_buffer = */ "", /* .trigger_buffer = */ "",
std::move(vec_trigger_tokens), std::move(vec_trigger_tokens),
std::move(vec_trigger_patterns), std::move(vec_trigger_words),
}; };
} }
@ -1087,7 +1089,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
} }
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) { struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
auto * result = new llama_grammar { llama_grammar * result = new llama_grammar {
grammar.vocab, grammar.vocab,
grammar.rules, grammar.rules,
grammar.stacks, grammar.stacks,
@ -1096,7 +1098,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
grammar.awaiting_trigger, grammar.awaiting_trigger,
grammar.trigger_buffer, grammar.trigger_buffer,
grammar.trigger_tokens, grammar.trigger_tokens,
grammar.trigger_patterns, grammar.trigger_words,
}; };
// redirect elements in stacks to point to new rules // redirect elements in stacks to point to new rules
@ -1171,22 +1173,20 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str()); LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
return; return;
} else { } else {
// TODO: consider a smarter incremental substring search algorithm (store last position to search from).
grammar.trigger_buffer += piece; grammar.trigger_buffer += piece;
for (const auto & word : grammar.trigger_words) {
std::smatch match; auto pos = grammar.trigger_buffer.find(word);
for (const auto & trigger_pattern : grammar.trigger_patterns) { if (pos != std::string::npos) {
if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
grammar.awaiting_trigger = false; grammar.awaiting_trigger = false;
// get from the first match to the end of the string auto constrained_str = grammar.trigger_buffer.substr(pos);
auto constrained_str = grammar.trigger_buffer.substr(match.position(1));
// std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
grammar.trigger_buffer.clear(); grammar.trigger_buffer.clear();
llama_grammar_accept_str(grammar, constrained_str); llama_grammar_accept_str(grammar, constrained_str);
LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str()); LLAMA_LOG_DEBUG("Grammar triggered on word `%s`", word.c_str());
return; return;
} }
} }
LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str()); LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`) (buffer: `%s`)\n", token, piece.c_str(), grammar.trigger_buffer.c_str());
return; return;
} }
} }

View File

@ -3,7 +3,6 @@
#include "llama.h" #include "llama.h"
#include <map> #include <map>
#include <regex>
#include <string> #include <string>
#include <vector> #include <vector>
@ -106,11 +105,6 @@ struct llama_grammar_parser {
void print(FILE * file); void print(FILE * file);
}; };
struct llama_grammar_trigger_pattern {
std::string pattern;
std::regex regex;
};
struct llama_grammar { struct llama_grammar {
// note: allow null vocab for testing (not great) // note: allow null vocab for testing (not great)
const llama_vocab * vocab; const llama_vocab * vocab;
@ -122,16 +116,13 @@ struct llama_grammar {
llama_partial_utf8 partial_utf8; llama_partial_utf8 partial_utf8;
// lazy grammars wait for trigger words or tokens before constraining the sampling. // lazy grammars wait for trigger words or tokens before constraining the sampling.
// we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens. // we still ahve trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
// (useful e.g. for tool_choice=required) // (useful e.g. for tool_choice=required)
bool lazy = false; bool lazy = false;
bool awaiting_trigger = false; // Initialized to true for lazy grammars only bool awaiting_trigger = false; // Initialized to true for lazy grammars only
std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found. std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special). std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
std::vector<llama_grammar_trigger_pattern> std::vector<std::string> trigger_words;
trigger_patterns; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated
// string, and the grammar will be given the string from the first match group onwards.
}; };
// //
@ -150,8 +141,8 @@ struct llama_grammar * llama_grammar_init_impl(
const char * grammar_str, const char * grammar_str,
const char * grammar_root, const char * grammar_root,
bool lazy, bool lazy,
const char ** trigger_patterns, const char ** trigger_words,
size_t num_trigger_patterns, size_t num_trigger_words,
const llama_token * trigger_tokens, const llama_token * trigger_tokens,
size_t num_trigger_tokens); size_t num_trigger_tokens);

File diff suppressed because it is too large Load Diff

View File

@ -1,594 +0,0 @@
#pragma once
#include "llama-arch.h"
#include "llama-hparams.h"
#include "llama-adapter.h"
#include <cstdint>
#include <vector>
#include <memory>
#include <set>
#include <functional>
struct ggml_cgraph;
struct ggml_context;
struct ggml_tensor;
struct llama_ubatch;
struct llama_cparams;
class llama_memory_i;
class llama_kv_cache_unified;
// certain models (typically multi-modal) can produce different types of graphs
enum llm_graph_type {
LLM_GRAPH_TYPE_DEFAULT,
LLM_GRAPH_TYPE_ENCODER,
LLM_GRAPH_TYPE_DECODER,
};
enum llm_ffn_op_type {
LLM_FFN_SILU,
LLM_FFN_GELU,
LLM_FFN_RELU,
LLM_FFN_RELU_SQR,
LLM_FFN_SWIGLU,
};
enum llm_ffn_gate_type {
LLM_FFN_SEQ,
LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
};
enum llm_norm_type {
LLM_NORM,
LLM_NORM_RMS,
LLM_NORM_GROUP,
};
// TODO: tmp - need something better to pass the data from the encoder to the decoder
struct llama_cross {
// the output embeddings from the encoder as a ggml tensor
// TODO: this needs more work to be correct, for now copy the embeddings data to host memory
// ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
//ggml_tensor * t_embd = nullptr;
int64_t n_embd = 0;
int64_t n_enc = 0;
// embeddings data copied to host memory (tmp)
std::vector<float> v_embd;
// needed to construct the cross-attention mask in the decoder
std::vector<std::set<llama_seq_id>> seq_ids_enc;
};
//
// llm_graph_input
//
class llm_graph_input_i {
public:
virtual ~llm_graph_input_i() = default;
virtual void set_input(const llama_ubatch * ubatch) = 0;
};
using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
class llm_graph_input_embd : public llm_graph_input_i {
public:
llm_graph_input_embd() = default;
virtual ~llm_graph_input_embd() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * tokens = nullptr; // I32 [n_batch]
ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
};
class llm_graph_input_pos : public llm_graph_input_i {
public:
llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
virtual ~llm_graph_input_pos() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * pos = nullptr; // I32 [n_batch]
const int64_t n_pos_per_embd = 1;
};
// temperature tuning, used by llama4
class llm_graph_input_attn_temp : public llm_graph_input_i {
public:
llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
: n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
virtual ~llm_graph_input_attn_temp() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
const uint32_t n_attn_temp_floor_scale;
const float f_attn_temp_scale;
};
class llm_graph_input_pos_bucket : public llm_graph_input_i {
public:
llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
virtual ~llm_graph_input_pos_bucket() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]
const llama_hparams & hparams;
};
class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
public:
llm_graph_input_pos_bucket_kv(
const llama_hparams & hparams,
const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {}
virtual ~llm_graph_input_pos_bucket_kv() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
const llama_hparams & hparams;
const llama_kv_cache_unified * kv_self;
};
class llm_graph_input_out_ids : public llm_graph_input_i {
public:
llm_graph_input_out_ids(
const llama_hparams & hparams,
const llama_cparams & cparams,
int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
virtual ~llm_graph_input_out_ids() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * out_ids; // I32 [n_outputs]
const llama_hparams & hparams;
const llama_cparams & cparams;
const int32_t n_outputs;
};
class llm_graph_input_mean : public llm_graph_input_i {
public:
llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {}
virtual ~llm_graph_input_mean() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * mean; // F32 [n_batch, n_batch]
const llama_cparams & cparams;
};
class llm_graph_input_cls : public llm_graph_input_i {
public:
llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
virtual ~llm_graph_input_cls() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * cls; // I32 [n_batch]
const llama_cparams & cparams;
};
class llm_graph_input_s_copy : public llm_graph_input_i {
public:
llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
virtual ~llm_graph_input_s_copy() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * s_copy; // I32 [kv_size]
const llama_kv_cache_unified * kv_self;
};
class llm_graph_input_s_mask : public llm_graph_input_i {
public:
llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
virtual ~llm_graph_input_s_mask() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * s_mask; // F32 [1, n_kv]
const llama_kv_cache_unified * kv_self;
};
class llm_graph_input_cross_embd : public llm_graph_input_i {
public:
llm_graph_input_cross_embd(
const llama_cross * cross) : cross(cross) {}
virtual ~llm_graph_input_cross_embd() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]
const llama_cross * cross;
};
class llm_graph_input_attn_no_cache : public llm_graph_input_i {
public:
llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
hparams(hparams),
cparams(cparams) {
}
~llm_graph_input_attn_no_cache() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch]
ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch]
const llama_hparams & hparams;
const llama_cparams & cparams;
};
class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
public:
llm_graph_input_attn_kv_unified(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_unified * kv_self) :
hparams(hparams),
cparams(cparams),
kv_self(kv_self) {
}
~llm_graph_input_attn_kv_unified() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_kv_cache_unified * kv_self;
};
class llm_graph_input_attn_cross : public llm_graph_input_i {
public:
llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
~llm_graph_input_attn_cross() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch]
ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
const llama_cross * cross = nullptr;
};
//
// llm_graph_result
//
// these objects deliver the result from the graph build process back to the llama_context
// note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
// specific data, by calling the set_inputs() method
// along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
// these are used by the llama_context to extact the relevant data, based on the compute parameters
class llm_graph_result_i {
public:
virtual ~llm_graph_result_i() = default;
virtual ggml_tensor * get_logits() = 0;
virtual ggml_tensor * get_embd() = 0;
virtual ggml_tensor * get_embd_pooled() = 0;
virtual void set_inputs(const llama_ubatch * ubatch) = 0;
};
using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
class llm_graph_result : public llm_graph_result_i {
public:
virtual ~llm_graph_result() = default;
ggml_tensor * get_logits() override { return t_logits; }
ggml_tensor * get_embd() override { return t_embd; }
ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
void set_inputs(const llama_ubatch * ubatch) override {
for (auto & input : inputs) {
input->set_input(ubatch);
}
}
llm_graph_input_i * add_input(llm_graph_input_ptr input) {
inputs.emplace_back(std::move(input));
return inputs.back().get();
}
// important graph nodes
ggml_tensor * t_logits = nullptr;
ggml_tensor * t_embd = nullptr;
ggml_tensor * t_embd_pooled = nullptr;
std::vector<llm_graph_input_ptr> inputs;
};
//
// llm_graph_context
//
// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
struct llm_graph_params {
ggml_context * ctx;
const llm_arch arch;
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_ubatch & ubatch;
ggml_backend_sched * sched;
ggml_backend * backend_cpu;
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
const llama_memory_i * memory;
const llama_cross * cross;
int32_t n_outputs;
const llm_graph_cb & cb;
};
struct llm_graph_context {
const llm_arch arch;
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_ubatch & ubatch;
const int64_t n_embd;
const int64_t n_layer;
const int64_t n_rot;
const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
const int64_t n_ctx_per_seq;
const int64_t n_head;
const int64_t n_head_kv;
const int64_t n_embd_head_k;
const int64_t n_embd_k_gqa;
const int64_t n_embd_head_v;
const int64_t n_embd_v_gqa;
const int64_t n_expert;
const int64_t n_expert_used;
const float freq_base;
const float freq_scale;
const float ext_factor;
const float attn_factor;
const float beta_fast;
const float beta_slow;
const float norm_eps;
const float norm_rms_eps;
const int32_t n_tokens;
const int32_t n_outputs;
const int32_t n_ctx_orig; // yarn
const enum llama_pooling_type pooling_type;
const enum llama_rope_type rope_type;
ggml_context * ctx0 = nullptr;
ggml_backend_sched * sched;
ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
const llama_memory_i * memory;
const llama_cross * cross;
const llm_graph_cb & cb_func;
std::unique_ptr<llm_graph_result> res;
llm_graph_context(const llm_graph_params & params);
int64_t n_pos_per_embd() const;
void cb(ggml_tensor * cur, const char * name, int il) const;
//
// common
//
ggml_tensor * build_cvec(
ggml_tensor * cur,
int il) const;
// do mat_mul, while optionally apply lora
ggml_tensor * build_lora_mm(
ggml_tensor * w,
ggml_tensor * cur) const;
// do mat_mul_id, while optionally apply lora
ggml_tensor * build_lora_mm_id(
ggml_tensor * w, // ggml_tensor * as
ggml_tensor * cur, // ggml_tensor * b
ggml_tensor * ids) const;
ggml_tensor * build_norm(
ggml_tensor * cur,
ggml_tensor * mw,
ggml_tensor * mb,
llm_norm_type type,
int il) const;
ggml_tensor * build_ffn(
ggml_tensor * cur,
ggml_tensor * up,
ggml_tensor * up_b,
ggml_tensor * up_s,
ggml_tensor * gate,
ggml_tensor * gate_b,
ggml_tensor * gate_s,
ggml_tensor * down,
ggml_tensor * down_b,
ggml_tensor * down_s,
ggml_tensor * act_scales,
llm_ffn_op_type type_op,
llm_ffn_gate_type type_gate,
int il) const;
ggml_tensor * build_moe_ffn(
ggml_tensor * cur,
ggml_tensor * gate_inp,
ggml_tensor * up_exps,
ggml_tensor * gate_exps,
ggml_tensor * down_exps,
ggml_tensor * exp_probs_b,
int64_t n_expert,
int64_t n_expert_used,
llm_ffn_op_type type_op,
bool norm_w,
bool scale_w,
float w_scale,
llama_expert_gating_func_type gating_op,
int il) const;
//
// inputs
//
ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
ggml_tensor * build_inp_pos() const;
ggml_tensor * build_inp_attn_scale() const;
ggml_tensor * build_inp_out_ids() const;
ggml_tensor * build_inp_mean() const;
ggml_tensor * build_inp_cls() const;
ggml_tensor * build_inp_s_copy() const;
ggml_tensor * build_inp_s_mask() const;
ggml_tensor * build_inp_cross_embd() const;
ggml_tensor * build_inp_pos_bucket_enc() const;
ggml_tensor * build_inp_pos_bucket_dec() const;
ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
//
// attention
//
ggml_tensor * build_attn_mha(
ggml_cgraph * gf,
ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
ggml_tensor * kq_b,
ggml_tensor * kq_mask,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
bool v_trans,
float kq_scale) const;
llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
ggml_tensor * build_attn(
llm_graph_input_attn_no_cache * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;
llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
ggml_tensor * build_attn(
llm_graph_input_attn_kv_unified * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;
llm_graph_input_attn_cross * build_attn_inp_cross() const;
ggml_tensor * build_attn(
llm_graph_input_attn_cross * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;
//
// recurrent
//
ggml_tensor * build_copy_mask_state(
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
int32_t n_state,
int32_t n_seqs) const;
ggml_tensor * build_rwkv_token_shift_load(
ggml_cgraph * gf,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const;
ggml_tensor * build_rwkv_token_shift_store(
ggml_tensor * token_shift,
const llama_ubatch & ubatch,
int il) const;
//
// pooling
//
void build_pooling(
ggml_cgraph * gf,
ggml_tensor * cls,
ggml_tensor * cls_b,
ggml_tensor * cls_out,
ggml_tensor * cls_out_b) const;
};

View File

@ -69,11 +69,3 @@ uint32_t llama_hparams::n_embd_v_s() const {
// corresponds to Mamba's ssm_states size // corresponds to Mamba's ssm_states size
return ssm_d_state * ssm_d_inner; return ssm_d_state * ssm_d_inner;
} }
bool llama_hparams::is_swa(uint32_t il) const {
if (il < n_layer) {
return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
}
GGML_ABORT("fatal error");
}

View File

@ -36,17 +36,12 @@ struct llama_hparams {
uint32_t n_layer; uint32_t n_layer;
uint32_t n_rot; uint32_t n_rot;
uint32_t n_swa = 0; // sliding window attention (SWA) uint32_t n_swa = 0; // sliding window attention (SWA)
uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
uint32_t n_expert = 0; uint32_t n_expert = 0;
uint32_t n_expert_used = 0; uint32_t n_expert_used = 0;
uint32_t n_rel_attn_bkts = 0; uint32_t n_rel_attn_bkts = 0;
// note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
uint32_t n_embd_head_k_mla = 0;
uint32_t n_embd_head_v_mla = 0;
// for WavTokenizer // for WavTokenizer
struct llama_hparams_posnet posnet; struct llama_hparams_posnet posnet;
struct llama_hparams_convnext convnext; struct llama_hparams_convnext convnext;
@ -66,7 +61,6 @@ struct llama_hparams {
float expert_weights_scale = 0.0; float expert_weights_scale = 0.0;
bool expert_weights_norm = false; bool expert_weights_norm = false;
uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE; uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
uint32_t moe_every_n_layers = 0;
float f_norm_eps; float f_norm_eps;
float f_norm_rms_eps; float f_norm_rms_eps;
@ -81,16 +75,10 @@ struct llama_hparams {
uint32_t time_decay_extra_dim = 0; uint32_t time_decay_extra_dim = 0;
uint32_t wkv_head_size = 0; uint32_t wkv_head_size = 0;
uint32_t token_shift_count = 2; uint32_t token_shift_count = 2;
uint32_t n_lora_decay = 0;
uint32_t n_lora_iclr = 0;
uint32_t n_lora_value_res_mix = 0;
uint32_t n_lora_gate = 0;
float rope_attn_factor = 1.0f; float rope_attn_factor = 1.0f;
float rope_freq_base_train; float rope_freq_base_train;
float rope_freq_base_train_swa;
float rope_freq_scale_train; float rope_freq_scale_train;
float rope_freq_scale_train_swa;
uint32_t n_ctx_orig_yarn; uint32_t n_ctx_orig_yarn;
float rope_yarn_log_mul; float rope_yarn_log_mul;
@ -117,14 +105,6 @@ struct llama_hparams {
bool use_alibi = false; bool use_alibi = false;
bool attn_soft_cap = false; bool attn_soft_cap = false;
uint32_t n_moe_layer_step = 0;
bool use_kq_norm = true;
uint32_t n_attn_chunk = 0;
// values below seems to be fixed on llama4
uint32_t n_no_rope_layer_step = 4;
uint32_t n_attn_temp_floor_scale = 8192;
float f_attn_temp_scale = 0.1;
// needed by encoder-decoder models (e.g. T5, FLAN-T5) // needed by encoder-decoder models (e.g. T5, FLAN-T5)
// ref: https://github.com/ggerganov/llama.cpp/pull/8141 // ref: https://github.com/ggerganov/llama.cpp/pull/8141
llama_token dec_start_token_id = LLAMA_TOKEN_NULL; llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
@ -153,8 +133,6 @@ struct llama_hparams {
// dimension of the recurrent state embeddings // dimension of the recurrent state embeddings
uint32_t n_embd_v_s() const; uint32_t n_embd_v_s() const;
bool is_swa(uint32_t il) const;
}; };
static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable"); static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");

Some files were not shown because too many files have changed in this diff Show More