mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-07-01 15:00:31 +02:00
Compare commits
143 Commits
v1.7.4-pre
...
fix_vs_sdl
Author | SHA1 | Date | |
---|---|---|---|
00ddb10fe2 | |||
b3a6018bbf | |||
d682e15090 | |||
46d07b9c85 | |||
33ea03f131 | |||
dbcc669e1a | |||
16245b35e4 | |||
898c0cb9d1 | |||
eb9e5032c4 | |||
cadfc50eab | |||
3f91832352 | |||
cff8868b5f | |||
90e3c5fc40 | |||
e0f4cef867 | |||
234460987e | |||
b8ab126343 | |||
edc5d9267c | |||
344b98a44f | |||
dbeb7916b8 | |||
fad2806352 | |||
9906792ec3 | |||
c49ee07ff4 | |||
f8a831779e | |||
85451e3612 | |||
43c744ce8b | |||
fc2e44490d | |||
f41fdad200 | |||
80fa576254 | |||
75e7d0585e | |||
682a6f5f87 | |||
115716d109 | |||
b2cfef655b | |||
22e3df0afa | |||
028511d349 | |||
70c4038842 | |||
8639c003a9 | |||
d5d831da65 | |||
7230a6e1c8 | |||
a160fa0f3a | |||
0282ad8fd1 | |||
9e467815d4 | |||
727891d9bf | |||
c262dc80e2 | |||
30767b4c4e | |||
16eeb31933 | |||
ba523d5e22 | |||
3736706139 | |||
58640aa456 | |||
5183a05e56 | |||
0dcada42d4 | |||
d507b4cebe | |||
90171055f3 | |||
668306ff2b | |||
fdc21fc87b | |||
7183a1eb72 | |||
09f3c66648 | |||
62e2414620 | |||
de49024e49 | |||
db6383094c | |||
164f13c6a9 | |||
02aa86230a | |||
54a2ee648f | |||
9700cfb0a3 | |||
8e0143e205 | |||
f12559d590 | |||
589b40810a | |||
7ffcd05267 | |||
7a423f1c00 | |||
99b011a9f5 | |||
19d95f9f9a | |||
d5ef1737d8 | |||
1deb41f0e7 | |||
2425caf4fd | |||
a4b00bcaaf | |||
cdb8aa2f2e | |||
06209f6683 | |||
c3235bd81e | |||
262d0abc87 | |||
124eec1664 | |||
b08c3a88c8 | |||
0afce25a69 | |||
acdbe58631 | |||
09fabffdf5 | |||
3988d6396b | |||
c8c63eeec0 | |||
abf7f24410 | |||
341f5c28e6 | |||
5377099524 | |||
dcbb375779 | |||
4334c71aed | |||
e875a82473 | |||
507e230f1e | |||
eb68324c86 | |||
e940fbf283 | |||
35d0e02c72 | |||
45d3faf961 | |||
2ab2eb5110 | |||
b82d305282 | |||
885e31368d | |||
8a9ad7844d | |||
eb874b3a3c | |||
eb78e3a3f1 | |||
ece3ff88f6 | |||
9366544991 | |||
95583942ed | |||
2e93cb6a2f | |||
de5cd60d1c | |||
3fcba3e58b | |||
cea5f1c52f | |||
2112462db4 | |||
fc84ecd445 | |||
8de1e99907 | |||
499af9294a | |||
bcf937c216 | |||
b8d90953d7 | |||
60a422147b | |||
3387415bad | |||
536ca3ec89 | |||
a4bb983190 | |||
39c205f555 | |||
6d502f33dc | |||
5ea27d089d | |||
1462d92588 | |||
7ba1a41f47 | |||
5ea088636f | |||
f32ddb3b1c | |||
79b75ece03 | |||
6348d73e55 | |||
fb36a1538a | |||
c81b8b910b | |||
85b60f31d0 | |||
227b5ffa36 | |||
36a64a253f | |||
c84b83c370 | |||
5136fd92c2 | |||
7d55637f0b | |||
0994506054 | |||
53c9a3a984 | |||
ed09075ca0 | |||
f07a81aa9f | |||
4183517076 | |||
f4668169a0 | |||
944ce49439 |
@ -12,7 +12,7 @@ FROM ${BASE_CUDA_DEV_CONTAINER} as build
|
||||
ARG CUDA_DOCKER_ARCH=all
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential git cmake libsdl2-dev wget
|
||||
apt-get install -y build-essential git cmake libsdl2-dev wget git
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
@ -17,7 +17,7 @@ ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
|
||||
ENV GGML_CUDA=1
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential libsdl2-dev wget cmake \
|
||||
apt-get install -y build-essential libsdl2-dev wget cmake git \
|
||||
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
|
||||
|
||||
# Ref: https://stackoverflow.com/a/53464012
|
||||
@ -33,7 +33,7 @@ ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl ffmpeg wget cmake \
|
||||
apt-get install -y curl ffmpeg wget cmake git \
|
||||
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
|
||||
|
||||
COPY --from=build /app /app
|
||||
|
@ -2,7 +2,7 @@ FROM ubuntu:22.04 AS build
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential wget cmake \
|
||||
apt-get install -y build-essential wget cmake git \
|
||||
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
|
||||
|
||||
COPY .. .
|
||||
@ -12,7 +12,7 @@ FROM ubuntu:22.04 AS runtime
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl ffmpeg libsdl2-dev wget cmake \
|
||||
apt-get install -y curl ffmpeg libsdl2-dev wget cmake git \
|
||||
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
|
||||
|
||||
COPY --from=build /app /app
|
||||
|
4
.github/workflows/bindings-go.yml
vendored
4
.github/workflows/bindings-go.yml
vendored
@ -10,8 +10,8 @@ on:
|
||||
- whisper.h
|
||||
|
||||
jobs:
|
||||
ubuntu-latest:
|
||||
runs-on: ubuntu-latest
|
||||
ubuntu-22:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
|
4
.github/workflows/bindings-ruby.yml
vendored
4
.github/workflows/bindings-ruby.yml
vendored
@ -42,8 +42,8 @@ on:
|
||||
- examples/dr_wav.h
|
||||
|
||||
jobs:
|
||||
ubuntu-latest:
|
||||
runs-on: ubuntu-latest
|
||||
ubuntu-22:
|
||||
runs-on: ubuntu-22.04
|
||||
defaults:
|
||||
run:
|
||||
working-directory: bindings/ruby
|
||||
|
300
.github/workflows/build.yml
vendored
300
.github/workflows/build.yml
vendored
@ -1,18 +1,28 @@
|
||||
name: CI
|
||||
on: [push, pull_request]
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
ubuntu_image: "ubuntu:22.04"
|
||||
VCPKG_BINARY_SOURCES: "clear;x-gha,readwrite"
|
||||
|
||||
jobs:
|
||||
ubuntu-latest:
|
||||
runs-on: ubuntu-latest
|
||||
ubuntu-22:
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le]
|
||||
arch: [linux/amd64, linux/ppc64le]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@ -28,10 +38,62 @@ jobs:
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
apt update
|
||||
apt install -y build-essential libsdl2-dev cmake
|
||||
apt install -y build-essential libsdl2-dev cmake git
|
||||
cmake -B build
|
||||
cmake --build build --config Release -j $(nproc)'
|
||||
|
||||
ubuntu-22-arm64:
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [linux/arm64]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Build ${{ matrix.arch }}
|
||||
run: |
|
||||
docker run --platform ${{ matrix.arch }} --rm \
|
||||
-v ${{ github.workspace }}:/workspace \
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
apt update
|
||||
apt install -y build-essential libsdl2-dev cmake git
|
||||
cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a
|
||||
cmake --build build --config Release -j $(nproc)'
|
||||
|
||||
ubuntu-22-arm-v7:
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [linux/arm/v7]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Build ${{ matrix.arch }}
|
||||
run: |
|
||||
docker run --platform ${{ matrix.arch }} --rm \
|
||||
-v ${{ github.workspace }}:/workspace \
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
apt update
|
||||
apt install -y build-essential libsdl2-dev cmake git
|
||||
cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv7-a+fp
|
||||
cmake --build build --config Release -j $(nproc)'
|
||||
|
||||
macOS-latest:
|
||||
runs-on: macOS-latest
|
||||
|
||||
@ -67,14 +129,14 @@ jobs:
|
||||
# cmake -B build
|
||||
# cmake --build build --config Release
|
||||
|
||||
ubuntu-latest-gcc:
|
||||
runs-on: ubuntu-latest
|
||||
ubuntu-22-gcc:
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
build: [Debug, Release]
|
||||
arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le]
|
||||
arch: [linux/amd64, linux/ppc64le]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@ -90,13 +152,69 @@ jobs:
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
apt update
|
||||
apt install -y build-essential cmake libsdl2-dev
|
||||
apt install -y build-essential cmake libsdl2-dev git
|
||||
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
make
|
||||
ctest -L gh --output-on-failure'
|
||||
|
||||
ubuntu-latest-clang:
|
||||
runs-on: ubuntu-latest
|
||||
ubuntu-22-gcc-arm64:
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
build: [Debug, Release]
|
||||
arch: [linux/arm64]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Build ${{ matrix.arch }}
|
||||
run: |
|
||||
docker run --platform ${{ matrix.arch }} --rm \
|
||||
-v ${{ github.workspace }}:/workspace \
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
apt update
|
||||
apt install -y build-essential cmake libsdl2-dev git
|
||||
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a
|
||||
make
|
||||
ctest -L gh --output-on-failure'
|
||||
|
||||
ubuntu-22-gcc-arm-v7:
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
build: [Debug, Release]
|
||||
arch: [linux/arm/v7]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Build ${{ matrix.arch }}
|
||||
run: |
|
||||
docker run --platform ${{ matrix.arch }} --rm \
|
||||
-v ${{ github.workspace }}:/workspace \
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
apt update
|
||||
apt install -y build-essential cmake libsdl2-dev git
|
||||
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv7-a+fp
|
||||
make
|
||||
ctest -L gh --output-on-failure'
|
||||
|
||||
ubuntu-22-clang:
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@ -121,13 +239,13 @@ jobs:
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
apt update
|
||||
apt install -y clang build-essential cmake libsdl2-dev
|
||||
apt install -y clang build-essential cmake libsdl2-dev git
|
||||
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang
|
||||
make
|
||||
ctest -L gh --output-on-failure'
|
||||
|
||||
ubuntu-latest-gcc-sanitized:
|
||||
runs-on: ubuntu-latest
|
||||
ubuntu-22-gcc-sanitized:
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@ -149,7 +267,7 @@ jobs:
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
apt update
|
||||
apt install -y build-essential cmake
|
||||
apt install -y build-essential cmake git
|
||||
cmake . -DCMAKE_BUILD_TYPE=Debug -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON
|
||||
make
|
||||
ctest -L gh --output-on-failure'
|
||||
@ -184,12 +302,12 @@ jobs:
|
||||
shell: bash
|
||||
run: |
|
||||
sudo apt update
|
||||
sudo apt install intel-oneapi-compiler-dpcpp-cpp
|
||||
sudo apt install intel-oneapi-compiler-dpcpp-cpp git
|
||||
|
||||
- name: install oneAPI MKL library
|
||||
shell: bash
|
||||
run: |
|
||||
sudo apt install intel-oneapi-mkl-devel
|
||||
sudo apt install intel-oneapi-mkl-devel git
|
||||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
@ -234,7 +352,7 @@ jobs:
|
||||
shell: bash
|
||||
run: |
|
||||
sudo apt update
|
||||
sudo apt install intel-oneapi-compiler-dpcpp-cpp
|
||||
sudo apt install intel-oneapi-compiler-dpcpp-cpp git
|
||||
|
||||
- name: install oneAPI MKL library
|
||||
shell: bash
|
||||
@ -275,6 +393,7 @@ jobs:
|
||||
msystem: ${{matrix.sys}}
|
||||
install: >-
|
||||
base-devel
|
||||
git
|
||||
mingw-w64-${{matrix.env}}-toolchain
|
||||
mingw-w64-${{matrix.env}}-cmake
|
||||
mingw-w64-${{matrix.env}}-SDL2
|
||||
@ -430,75 +549,76 @@ jobs:
|
||||
name: whisper-blas-bin-${{ matrix.arch }}
|
||||
path: build/bin/${{ matrix.build }}
|
||||
|
||||
# TODO: fix and re-enable
|
||||
# windows-cublas:
|
||||
# runs-on: windows-2019
|
||||
#
|
||||
# strategy:
|
||||
# matrix:
|
||||
# build: [Release]
|
||||
# arch: [x64]
|
||||
# cublas: [ON]
|
||||
# sdl2: [ON]
|
||||
# cuda-toolkit: [12.2.0, 11.8.0]
|
||||
# include:
|
||||
# - arch: x64
|
||||
# s2arc: x64
|
||||
# - sdl2: ON
|
||||
# s2ver: 2.28.5
|
||||
#
|
||||
# steps:
|
||||
# - name: Clone
|
||||
# uses: actions/checkout@v4
|
||||
#
|
||||
# - name: Add msbuild to PATH
|
||||
# uses: microsoft/setup-msbuild@v2
|
||||
#
|
||||
# - name: Install CUDA Toolkit
|
||||
# id: cuda-toolkit
|
||||
# uses: Jimver/cuda-toolkit@v0.2.15
|
||||
# with:
|
||||
# cuda: '${{ matrix.cuda-toolkit }}'
|
||||
#
|
||||
# - name: Fetch SDL2 and set SDL2_DIR
|
||||
# if: matrix.sdl2 == 'ON'
|
||||
# run: |
|
||||
# C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip
|
||||
# 7z x sdl2.zip
|
||||
# echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
|
||||
#
|
||||
# - name: Configure
|
||||
# run: >
|
||||
# cmake -S . -B ./build -A ${{ matrix.arch }}
|
||||
# -DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
# -DGGML_CUDA=${{ matrix.cublas }}
|
||||
# -DWHISPER_SDL2=${{ matrix.sdl2 }}
|
||||
#
|
||||
# - name: Build ${{ matrix.cuda-toolkit }}
|
||||
# run: |
|
||||
# cd ./build
|
||||
# cmake --build . --config ${{ matrix.build }}
|
||||
#
|
||||
# - name: Copy CUDA DLLs
|
||||
# run: >
|
||||
# Copy-Item -PassThru
|
||||
# -Path "${{ steps.cuda-toolkit.outputs.CUDA_PATH }}/bin/*.dll"
|
||||
# -Include cudart64_*,cublas64_*,cublasLt64_*
|
||||
# -Destination build/bin/${{ matrix.build }}
|
||||
#
|
||||
# - name: Copy SDL2.dll
|
||||
# if: matrix.sdl2 == 'ON'
|
||||
# run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
|
||||
#
|
||||
# - name: Upload binaries
|
||||
# if: matrix.sdl2 == 'ON'
|
||||
# uses: actions/upload-artifact@v4
|
||||
# with:
|
||||
# name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}
|
||||
# path: build/bin/${{ matrix.build }}
|
||||
windows-cublas:
|
||||
runs-on: windows-2019
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Release]
|
||||
arch: [x64]
|
||||
cublas: [ON]
|
||||
sdl2: [ON]
|
||||
cuda-toolkit: [12.2.0, 11.8.0]
|
||||
include:
|
||||
- arch: x64
|
||||
sdl2: ON
|
||||
sdl2_ver: 2.28.5
|
||||
steps:
|
||||
- name: Clone repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Add msbuild to PATH
|
||||
uses: microsoft/setup-msbuild@v2
|
||||
|
||||
- name: Install CUDA Toolkit
|
||||
id: cuda-toolkit
|
||||
uses: Jimver/cuda-toolkit@v0.2.15
|
||||
with:
|
||||
cuda: '${{ matrix.cuda-toolkit }}'
|
||||
|
||||
- name: Install 7-Zip
|
||||
run: choco install 7zip -y
|
||||
|
||||
- name: Fetch SDL2 and set SDL2_DIR
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: |
|
||||
Invoke-WebRequest -Uri https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.sdl2_ver }}/SDL2-devel-${{ matrix.sdl2_ver }}-VC.zip -OutFile sdl2.zip
|
||||
7z x sdl2.zip
|
||||
echo "SDL2_DIR=${{ github.workspace }}\SDL2-${{ matrix.sdl2_ver }}\cmake" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "${{ github.workspace }}\SDL2-${{ matrix.sdl2_ver }}\cmake" > SDL2_PATH.txt
|
||||
|
||||
- name: Configure CMake
|
||||
shell: cmd
|
||||
run: |
|
||||
cmake -S . -B ./build -A ${{ matrix.arch }} ^
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }} ^
|
||||
-DGGML_CUDA=${{ matrix.cublas }} ^
|
||||
-DCMAKE_CUDA_ARCHITECTURES=all ^
|
||||
-DWHISPER_SDL2=${{ matrix.sdl2 }} ^
|
||||
-DSDL2_DIR="%SDL2_DIR%"
|
||||
|
||||
- name: Build Project
|
||||
shell: cmd
|
||||
run: |
|
||||
cd ./build
|
||||
cmake --build . --config ${{ matrix.build }}
|
||||
|
||||
- name: Copy CUDA DLLs
|
||||
run: |
|
||||
Get-ChildItem "${{ steps.cuda-toolkit.outputs.CUDA_PATH }}/bin/" -Filter "*.dll" |
|
||||
Copy-Item -Destination "build/bin/${{ matrix.build }}"
|
||||
|
||||
- name: Copy SDL2.dll
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: copy "$env:SDL2_DIR/../lib/${{ matrix.arch }}/SDL2.dll" build/bin/${{ matrix.build }}
|
||||
|
||||
- name: Upload binaries
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}
|
||||
path: build/bin/${{ matrix.build }}
|
||||
|
||||
emscripten:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
@ -558,14 +678,14 @@ jobs:
|
||||
run: |
|
||||
xcodebuild -scheme whisper-Package -destination 'generic/platform=iOS'
|
||||
|
||||
#- name: Build objc example
|
||||
# run: xcodebuild -project examples/whisper.objc/whisper.objc.xcodeproj -scheme whisper.objc -configuration ${{ matrix.build }} -sdk iphoneos build
|
||||
- name: Build objc example
|
||||
run: xcodebuild -project examples/whisper.objc/whisper.objc.xcodeproj -scheme whisper.objc -configuration ${{ matrix.build }} -sdk iphoneos CODE_SIGN_IDENTITY="" CODE_SIGNING_REQUIRED=NO build
|
||||
|
||||
- name: Build swiftui example
|
||||
run: xcodebuild -project examples/whisper.swiftui/whisper.swiftui.xcodeproj -scheme WhisperCppDemo -configuration ${{ matrix.build }} -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' build
|
||||
|
||||
android:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@ -595,7 +715,7 @@ jobs:
|
||||
|
||||
# TODO: disable because of following fail: https://github.com/ggerganov/whisper.cpp/actions/runs/11019444420/job/30627193602
|
||||
# android_java:
|
||||
# runs-on: ubuntu-latest
|
||||
# runs-on: ubuntu-22.04
|
||||
#
|
||||
# steps:
|
||||
# - name: Clone
|
||||
@ -664,7 +784,7 @@ jobs:
|
||||
# PGP_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }}
|
||||
|
||||
quantize:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
|
4
.github/workflows/docker.yml
vendored
4
.github/workflows/docker.yml
vendored
@ -11,13 +11,13 @@ jobs:
|
||||
name: Push Docker image to Docker Hub
|
||||
if: github.event.pull_request.draft == false
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
COMMIT_SHA: ${{ github.sha }}
|
||||
strategy:
|
||||
matrix:
|
||||
config:
|
||||
- { tag: "main", dockerfile: ".devops/main.Dockerfile", platform: "linux/amd64,linux/arm64" }
|
||||
- { tag: "main", dockerfile: ".devops/main.Dockerfile", platform: "linux/amd64" }
|
||||
#TODO: the cuda image keeps failing - disable for now
|
||||
# https://github.com/ggerganov/whisper.cpp/actions/runs/11019444428/job/30602020339
|
||||
#- { tag: "main-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platform: "linux/amd64" }
|
||||
|
6
.github/workflows/examples.yml
vendored
6
.github/workflows/examples.yml
vendored
@ -10,8 +10,8 @@ on:
|
||||
- whisper.h
|
||||
|
||||
jobs:
|
||||
addon_node-ubuntu-latest:
|
||||
runs-on: ubuntu-latest
|
||||
addon_node-ubuntu-22:
|
||||
runs-on: ubuntu-22.04
|
||||
strategy:
|
||||
matrix:
|
||||
node-version: [ 16.x, 18.x ]
|
||||
@ -22,7 +22,7 @@ jobs:
|
||||
- name: Dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install build-essential
|
||||
sudo apt-get install build-essential git
|
||||
sudo apt-get install cmake
|
||||
sudo apt-get install libsdl2-dev
|
||||
|
||||
|
211
AUTHORS
211
AUTHORS
@ -1,34 +1,51 @@
|
||||
# date: Tue Apr 9 20:27:03 EEST 2024
|
||||
# date: Tue Feb 4 13:03:35 EET 2025
|
||||
# this file is auto-generated by scripts/gen-authors.sh
|
||||
|
||||
0/0 <zero@imaskeleton.me>
|
||||
0cc4m <picard12@live.de>
|
||||
0xsourcecode <134374803+0xsourcecode@users.noreply.github.com>
|
||||
65a <10104049+65a@users.noreply.github.com>
|
||||
AIWintermuteAI <32562299+AIWintermuteAI@users.noreply.github.com>
|
||||
AT <manyoso@users.noreply.github.com>
|
||||
Aarni Koskela <akx@iki.fi>
|
||||
Aaron Pham <29749331+aarnphm@users.noreply.github.com>
|
||||
Aaron Taylor <aaron@exphat.com>
|
||||
Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com>
|
||||
Abitofevrything <54505189+abitofevrything@users.noreply.github.com>
|
||||
Adam Jones <domdomegg+git@gmail.com>
|
||||
Adrien Gallouët <adrien@gallouet.fr>
|
||||
Adrien Gallouët <angt@huggingface.co>
|
||||
AfryMask <AfryMask@163.com>
|
||||
Ahmad Bilal <ahmad.bilal@empglabs.com>
|
||||
Ahmad Tameem <113388789+Tameem-10xE@users.noreply.github.com>
|
||||
AidanBeltonS <87009434+AidanBeltonS@users.noreply.github.com>
|
||||
AidanBeltonS <aidan.belton@codeplay.com>
|
||||
Akarshan Biswas <akarshan.biswas@gmail.com>
|
||||
Akarshan Biswas <akarshanbiswas@fedoraproject.org>
|
||||
Akash Mahajan <akash7190@gmail.com>
|
||||
Akash Mahajan <akashmjn@stanford.edu>
|
||||
Al Hoang <3811822-hoanga@users.noreply.gitlab.com>
|
||||
Alan <unknown>
|
||||
Albert Jin <albert.jin@gmail.com>
|
||||
Alberto Cabrera Pérez <alberto.cabrera@codeplay.com>
|
||||
Alberto Cabrera Pérez <alberto.cabrera@intel.com>
|
||||
Aleksander Andrzejewski <18704749+aleksanderandrzejewski@users.noreply.github.com>
|
||||
Alex Azarov <alex@azarov.by>
|
||||
Alex Bacart <13940752+alex-bacart@users.noreply.github.com>
|
||||
Alex Evgrashin <aevgrashin@yandex.ru>
|
||||
Alex O'Connell <35843486+acon96@users.noreply.github.com>
|
||||
Alexandr Graschenkov <alexandr.graschenkov91@gmail.com>
|
||||
Alexandru Mariuti <alex@mariuti.com>
|
||||
Alexey Kharlamov <alexey@kharlamov.biz>
|
||||
Alfredo Montesinos <alfredo.montesinos@g.austincc.edu>
|
||||
Ali Alameh <ali.alameh@isae.edu.lb>
|
||||
Alter <0x7c48@gmail.com>
|
||||
Ananta Bastola <anantarajbastola@gmail.com>
|
||||
Andreas Kieslinger <47689530+aendk@users.noreply.github.com>
|
||||
Andreas Lubbe <git@lubbe.org>
|
||||
Andreu Huguet <andreuhuguet@gmail.com>
|
||||
Andrew Huynh <a5thuynh@gmail.com>
|
||||
Andrew Minh Nguyen <40281306+amqdn@users.noreply.github.com>
|
||||
Andrew S <andrews54757@gmail.com>
|
||||
Andy Maloney <asmaloney@gmail.com>
|
||||
Anton Kostin <masguit42@users.noreply.github.com>
|
||||
@ -40,8 +57,11 @@ AustinMroz <austinmroz@utexas.edu>
|
||||
Avik Sengupta <avik@sengupta.net>
|
||||
Bader-eddine Ouaich <49657842+baderouaich@users.noreply.github.com>
|
||||
Baffin Lee <baffinlee@gmail.com>
|
||||
Ben Ashbaugh <ben.ashbaugh@intel.com>
|
||||
Ben Nortier <bjnortier@gmail.com>
|
||||
Benjamin Heiniger <benjamin.heiniger@bluewin.ch>
|
||||
Bernhard M. Wiedemann <githubbmwprimary@lsmod.de>
|
||||
Binozo <70137898+Binozo@users.noreply.github.com>
|
||||
Bo-Yi Wu <appleboy.tw@gmail.com>
|
||||
Boris Bliznioukov <blib@mail.com>
|
||||
Borislav Stanimirov <b.stanimirov@abv.bg>
|
||||
@ -49,47 +69,86 @@ Brad Murray <59848399+bradmurray-dt@users.noreply.github.com>
|
||||
Brian Murray <brian@bmurray.ca>
|
||||
CRD716 <crd716@gmail.com>
|
||||
Canis Lupus <Canis-UK@users.noreply.github.com>
|
||||
Carlos Zoido <mrgalleta@gmail.com>
|
||||
Carolinabanana <140120812+Carolinabanana@users.noreply.github.com>
|
||||
CarterLi999 <664681047@qq.com>
|
||||
ChangSeok Oh <shivamidow@users.noreply.github.com>
|
||||
Changyeon Kim <cyzero.kim@samsung.com>
|
||||
Chaoqun <27287694+OpenWaygate@users.noreply.github.com>
|
||||
Charles Xu <63788048+chaxu01@users.noreply.github.com>
|
||||
Charles Xu <charles.xu@arm.com>
|
||||
Chen Xi <xi2.chen@intel.com>
|
||||
Chen Xi <xixichen08@foxmail.com>
|
||||
Chenguang Li <87689256+noemotiovon@users.noreply.github.com>
|
||||
Chia-Hsiang Cheng <88014292+garychia@users.noreply.github.com>
|
||||
Chidi Williams <williamschidi1@gmail.com>
|
||||
Chris Elrod <elrodc@gmail.com>
|
||||
Christian <12550267+iceychris@users.noreply.github.com>
|
||||
Christian Kastner <ckk@kvr.at>
|
||||
Clifford Heath <clifford.heath@gmail.com>
|
||||
Clint Herron <hanclinto@gmail.com>
|
||||
Colin <github@whoisc.cc>
|
||||
Conrad Kramer <conrad@conradkramer.com>
|
||||
Corey Earwood <iamcgn+github@gmail.com>
|
||||
CrispStrobe <154636388+CrispStrobe@users.noreply.github.com>
|
||||
DAN™ <dranger003@gmail.com>
|
||||
DGdev91 <DGdev91@users.noreply.github.com>
|
||||
Damian Czaja <trojan295@protonmail.com>
|
||||
Dan Johansson <164997844+eddnjjn@users.noreply.github.com>
|
||||
Dan Johansson <dan.johansson@arm.com>
|
||||
Daniel Bevenius <daniel.bevenius@gmail.com>
|
||||
Daniel Valdivia <18384552+dvaldivia@users.noreply.github.com>
|
||||
Daniel Ziegenberg <daniel@ziegenberg.at>
|
||||
Daniele <57776841+daniandtheweb@users.noreply.github.com>
|
||||
Dave <dave-fl@users.noreply.github.com>
|
||||
Dave Airlie <airlied@gmail.com>
|
||||
Dave Airlie <airlied@redhat.com>
|
||||
Daven Sanassy <daven@vochlea.co.uk>
|
||||
David <dnhkng@gmail.com>
|
||||
David Thorpe <djt@mutablelogic.com>
|
||||
DavidKorczynski <david@adalogics.com>
|
||||
Davidson Francis <davidsondfgl@gmail.com>
|
||||
Dener Stassun <denerstassun@gmail.com>
|
||||
Dibakar Gope <dibakar.gope@arm.com>
|
||||
Didzis Gosko <didzis@users.noreply.github.com>
|
||||
Diego Devesa <slarengh@gmail.com>
|
||||
Digipom <admin@digipom.com>
|
||||
Dimo <dimo@ieee.org>
|
||||
Djip007 <3705339+Djip007@users.noreply.github.com>
|
||||
Djip007 <djip.perois@free.fr>
|
||||
Dody Suria Wijaya <dodysw@gmail.com>
|
||||
Dou Xinpeng <15529241576@163.com>
|
||||
Dou Xinpeng <81913537+Dou-Git@users.noreply.github.com>
|
||||
Dr. Tom Murphy VII Ph.D <499244+tom7@users.noreply.github.com>
|
||||
Duncan McConnell <ddmcconnell4@gmail.com>
|
||||
Egor Egorov <me@egorfine.com>
|
||||
Elkana Bardugo <ttv200@gmail.com>
|
||||
Emmanuel Schmidbauer <eschmidbauer@gmail.com>
|
||||
Engininja2 <139037756+Engininja2@users.noreply.github.com>
|
||||
Eric Curtin <ericcurtin17@gmail.com>
|
||||
Eric Swanson <eswanson@alloscomp.com>
|
||||
Eric Tendian <erictendian@gmail.com>
|
||||
Eric Zhang <34133756+EZForever@users.noreply.github.com>
|
||||
Erik Scholz <Green-Sky@users.noreply.github.com>
|
||||
Evan Jones <evan.q.jones@gmail.com>
|
||||
Evan Martin <evan.martin@gmail.com>
|
||||
Eve <139727413+netrunnereve@users.noreply.github.com>
|
||||
Evgeny Kuznetsov <evgeny@kuznetsov.md>
|
||||
F1L1P <78918286+F1L1Pv2@users.noreply.github.com>
|
||||
Faisal Zaghloul <quic_fzaghlou@quicinc.com>
|
||||
Fangjun Kuang <csukuangfj@gmail.com>
|
||||
Felix <stenbackfelix@gmail.com>
|
||||
Finn Voorhees <finnvoorhees@gmail.com>
|
||||
FirstTimeEZ <179362031+FirstTimeEZ@users.noreply.github.com>
|
||||
FlippFuzz <41221030+FlippFuzz@users.noreply.github.com>
|
||||
Frankie Robertson <frankier@users.noreply.github.com>
|
||||
Gang Chen <goncha@gmail.com>
|
||||
Gavin Cai <gavin1818@hotmail.com>
|
||||
George Hindle <george@georgehindle.com>
|
||||
Georgi Gerganov <ggerganov@gmail.com>
|
||||
Gilad S <7817232+giladgd@users.noreply.github.com>
|
||||
Gilad S <giladgd@users.noreply.github.com>
|
||||
Gilad S. <7817232+giladgd@users.noreply.github.com>
|
||||
GitAritron <103900385+GitAritron@users.noreply.github.com>
|
||||
GiviMAD <GiviMAD@users.noreply.github.com>
|
||||
Gleicon Moraes <gleicon@gmail.com>
|
||||
@ -98,41 +157,66 @@ Guillaume Wenzek <gwenzek@users.noreply.github.com>
|
||||
HY. Kelvin Lee <34256578+hykelvinlee42@users.noreply.github.com>
|
||||
Halalaluyafail3 <55773281+Halalaluyafail3@users.noreply.github.com>
|
||||
Hang <bebound@gmail.com>
|
||||
Haus1 <haus.xda@gmail.com>
|
||||
Herman Semenov <GermanAizek@yandex.ru>
|
||||
HimariO <dsfhe49854@gmail.com>
|
||||
Hong Bo PENG <penghb@cn.ibm.com>
|
||||
Hrishikesh Barman <geekodour@users.noreply.github.com>
|
||||
Hugo <hugo@whynothugo.nl>
|
||||
Ian Bicking <ian@ianbicking.org>
|
||||
Ian Bull <irbull@eclipsesource.com>
|
||||
Ihar Hrachyshka <ihrachys@redhat.com>
|
||||
Ikko Ashimine <eltociear@gmail.com>
|
||||
Ikko Eltociear Ashimine <eltociear@gmail.com>
|
||||
InconsolableCellist <23345188+InconsolableCellist@users.noreply.github.com>
|
||||
Ismatulla Mansurov <47342870+sapoepsilon@users.noreply.github.com>
|
||||
Ivan <nekotekina@gmail.com>
|
||||
Ivan Filipov <159561759+vanaka11@users.noreply.github.com>
|
||||
Ivan Gorin <ivangorin21@gmail.com>
|
||||
Ivo von Putzer Reibegg <ivo.putzer@gmail.com>
|
||||
JJ <103335846+computerscienceiscool@users.noreply.github.com>
|
||||
Jack Mousseau <jmousseau@users.noreply.github.com>
|
||||
JacobLinCool <jacoblincool@gmail.com>
|
||||
Jakub Ráček <blizzcz@gmail.com>
|
||||
Jared Van Bortel <jared@nomic.ai>
|
||||
Jay Binks <jaybinks@gmail.com>
|
||||
Jayant <jayantyadav202@gmail.com>
|
||||
Jeff Bolz <jbolz@nvidia.com>
|
||||
Jeroen Mostert <jeroen.mostert@cm.com>
|
||||
Jhen-Jie Hong <developer@jhen.me>
|
||||
Jhen-Jie Hong <iainst0409@gmail.com>
|
||||
JidongZhang-THU <1119708529@qq.com>
|
||||
Jo Liss <joliss42@gmail.com>
|
||||
Joe Todd <joe.todd@codeplay.com>
|
||||
Johan <jr.raffin@gmail.com>
|
||||
Johannes Gäßler <johannesg@5d6.de>
|
||||
John Balis <phobossystems@gmail.com>
|
||||
JohnnyB <jboero@users.noreply.github.com>
|
||||
Jonathan Soo <jcsoo@agora.com>
|
||||
Jonno <1160532+razodactyl@users.noreply.github.com>
|
||||
Joonas Pihlajamaa <joonas.pihlajamaa@iki.fi>
|
||||
Jose <34888496+Jerry-Master@users.noreply.github.com>
|
||||
Josh Bleecher Snyder <josharian@gmail.com>
|
||||
Josscii <jossciiweiyi@gmail.com>
|
||||
Judd <foldl@users.noreply.github.com>
|
||||
Jumper775 <78500318+jumpers775@users.noreply.github.com>
|
||||
Jun Hee Yoo <contact.jhyoo@gmail.com>
|
||||
Junil Kim <logyourself@gmail.com>
|
||||
Justina Cho <justcho5@gmail.com>
|
||||
Justine Tunney <jtunney@gmail.com>
|
||||
Justine Tunney <jtunney@mozilla.com>
|
||||
KITAITI Makoto <KitaitiMakoto@gmail.com>
|
||||
KP Kaiser <kirk@zothcorp.com>
|
||||
Kamilake <exjang0@gmail.com>
|
||||
Karol Kontny <82021046+kkontny@users.noreply.github.com>
|
||||
Karthick <j.karthic2004@gmail.com>
|
||||
Kartik Saranathan <278928+Kartiku@users.noreply.github.com>
|
||||
Kasumi <90275229+kasumi-1@users.noreply.github.com>
|
||||
Kawrakow <48489457+ikawrakow@users.noreply.github.com>
|
||||
Kendrick Taylor <kendrick@circuitsix.com>
|
||||
Kevin Brothaler <admin@digipom.com>
|
||||
Kevin Gibbons <bakkot@gmail.com>
|
||||
Konosuke Sakai <konosuke@konosuke.work>
|
||||
Konstantin Zhuravlyov <konstantin.zhuravlyov@amd.com>
|
||||
Kreijstal <rainb@tfwno.gf>
|
||||
Kylin <56434533+KyL0N@users.noreply.github.com>
|
||||
@ -147,56 +231,110 @@ Luis Herrera <herrera-luis@users.noreply.github.com>
|
||||
Lukas Rist <glaslos@gmail.com>
|
||||
M. A. Ali <73258591+MightyStud@users.noreply.github.com>
|
||||
M. Eren Akbiyik <erenakbiyik@gmail.com>
|
||||
Ma Mingfei <mingfei.ma@intel.com>
|
||||
Maciek <maciek.mab122@gmail.com>
|
||||
Mahesh Madhav <67384846+heshpdx@users.noreply.github.com>
|
||||
Marcin Mielniczuk <marmistrz.dev@zoho.eu>
|
||||
Mark Karpelès <MagicalTux@users.noreply.github.com>
|
||||
Mark Zhuang <zhuangqiubin@gmail.com>
|
||||
Markus Tavenrath <mtavenrath@users.noreply.github.com>
|
||||
Martin Delille <martin@delille.org>
|
||||
Martin Warnaar <martinwarnaar@gmail.com>
|
||||
Masaya, Kato <62578291+msy-kato@users.noreply.github.com>
|
||||
Matheus de Sousa <23645013+keyehzy@users.noreply.github.com>
|
||||
Mathieu Baudier <mbaudier@argeo.org>
|
||||
Mathijs de Bruin <mathijs@mathijsfietst.nl>
|
||||
Matija Pevec <mightymatth@users.noreply.github.com>
|
||||
Matt Stephenson <mstephenson6@users.noreply.github.com>
|
||||
Max Krasnyansky <max.krasnyansky@gmail.com>
|
||||
Max Krasnyansky <quic_maxk@quicinc.com>
|
||||
Maximiliano Levi <8160966+maxilevi@users.noreply.github.com>
|
||||
Meng, Hengyu <hengyu.meng@intel.com>
|
||||
Mengqing Cao <cmq0113@163.com>
|
||||
Michael Podvitskiy <podvitskiymichael@gmail.com>
|
||||
Michael Rienstra <mrienstra@gmail.com>
|
||||
Mikhail Grigorev <sleuthhound@gmail.com>
|
||||
Mohammadreza Hendiani <hendiani.mohammadreza@gmail.com>
|
||||
Mohit Agarwal <mohit@sdf.org>
|
||||
Molly Sophia <mollysophia379@gmail.com>
|
||||
Murilo Santana <mvrilo@gmail.com>
|
||||
NETZkultur GmbH <mulholland@netzkultur.de>
|
||||
Natsu <chino@hotococoa.moe>
|
||||
Neil Chudleigh <nchudleigh@users.noreply.github.com>
|
||||
Neo Zhang <14088817+arthw@users.noreply.github.com>
|
||||
Neo Zhang Jianyu <jianyu.zhang@intel.com>
|
||||
Neuman Vong <neuman.vong@gmail.com>
|
||||
Nicholai Tukanov <nicholaitukanov@gmail.com>
|
||||
Nicholas Albion <nalbion@yahoo.com>
|
||||
Nico Bosshard <nico@bosshome.ch>
|
||||
Nicolò Scipione <nicolo.scipione@codeplay.com>
|
||||
Niels Mayer <Niels.Mayer@gmail.com>
|
||||
Nikita Sarychev <42014488+sARY77@users.noreply.github.com>
|
||||
Nikolaj Olsson <nikse.dk@gmail.com>
|
||||
Okabintaro <103938900+Okabintaro@users.noreply.github.com>
|
||||
Oleg Sidorov <me@whitebox.io>
|
||||
Oleg Sidorov <oleg@sidorov.nl>
|
||||
Olivier Chafik <ochafik@users.noreply.github.com>
|
||||
Ondrej Kokes <ondrej.kokes@gmail.com>
|
||||
Ouadie EL FAROUKI <ouadie.elfarouki@codeplay.com>
|
||||
PAB <pierreantoine.bannier@gmail.com>
|
||||
Paul Tsochantaris <ptsochantaris@icloud.com>
|
||||
Pedro Probst <pprobst@insiberia.net>
|
||||
Peng <hzp1024@qq.com>
|
||||
Peter <peter277@users.noreply.github.com>
|
||||
Philipp Zabel <philipp.zabel@gmail.com>
|
||||
Philippe Normand <phil@base-art.net>
|
||||
Philippe Normand <philn@igalia.com>
|
||||
Plamen Minev <pacominev@gmail.com>
|
||||
Prashant Vithule <119530321+Vithulep@users.noreply.github.com>
|
||||
Przemysław Pawełczyk <przemoc@gmail.com>
|
||||
Qianhe Chen <54462604+chenqianhe@users.noreply.github.com>
|
||||
R0CKSTAR <xiaodong.ye@mthreads.com>
|
||||
R0CKSTAR <yeahdongcn@gmail.com>
|
||||
Radoslav Gerganov <rgerganov@gmail.com>
|
||||
Radosław Gryta <radek.gryta@gmail.com>
|
||||
Rahul Vadhyar <107788610+RahulVadhyar@users.noreply.github.com>
|
||||
Raiya Araki <83504221+rai62@users.noreply.github.com>
|
||||
Reinforce-II <fate@eastal.com>
|
||||
Reinis Muiznieks <muiznieks.reinis@gmail.com>
|
||||
RelatedTitle <r3latedtitle@gmail.com>
|
||||
Rémy Oudompheng <oudomphe@phare.normalesup.org>
|
||||
RhinoDevel <RhinoDevel@users.noreply.github.com>
|
||||
Rich Jones <miserlou@gmail.com>
|
||||
Robert Ormandi <52251610+ormandi@users.noreply.github.com>
|
||||
Robin <robin.xw@hotmail.com>
|
||||
Roddur Dasgupta <roddurd@gmail.com>
|
||||
Roland Rabien <figbug@gmail.com>
|
||||
Romain Biessy <romain.biessy@codeplay.com>
|
||||
Ronsor <ronsor@ronsor.pw>
|
||||
Rotem Dan <rotemdan@gmail.com>
|
||||
Ryan Hitchman <hitchmanr@gmail.com>
|
||||
Ryan Metcalfe <107415876+RyanMetcalfeInt8@users.noreply.github.com>
|
||||
RyanChang <ftes90015@gmail.com>
|
||||
SRHMorris <69468379+SRHMorris@users.noreply.github.com>
|
||||
SXX <sxx1136965276@gmail.com>
|
||||
Sacha Arbonel <sacha.arbonel@hotmail.fr>
|
||||
Salman Faroz <stsfaroz@gmail.com>
|
||||
Salvatore Mesoraca <s.mesoraca16@gmail.com>
|
||||
Sam <49637763+Onlyartist9@users.noreply.github.com>
|
||||
Sam Pullara <spullara@gmail.com>
|
||||
Samuel Durante <44513615+samueldurantes@users.noreply.github.com>
|
||||
Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
|
||||
Sandro Hanea <40202887+sandrohanea@users.noreply.github.com>
|
||||
Sergio López <slp@redhat.com>
|
||||
Sergio López <slp@sinrega.org>
|
||||
Shanshan Shen <467638484@qq.com>
|
||||
Shijie <821898965@qq.com>
|
||||
Shupei Fan <dymarkfan@outlook.com>
|
||||
Siddharth Ramakrishnan <srr2141@columbia.edu>
|
||||
Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
|
||||
Simon Moisselin <simon.moisstoll@gmail.com>
|
||||
Sindre Sorhus <sindresorhus@gmail.com>
|
||||
Slava Primenko <primenko.s@gmail.com>
|
||||
Srihari-mcw <96763064+Srihari-mcw@users.noreply.github.com>
|
||||
Stavros Panakakis <53979866+Stavrospanakakis@users.noreply.github.com>
|
||||
Stefan Sydow <s.sydow@heinlein-video.de>
|
||||
Stefan Sydow <stefan@sydow.email>
|
||||
Syahmi Azhar <prsyahmi@gmail.com>
|
||||
Syed Jafri <syedjafri97@gmail.com>
|
||||
Sơn Phan Trung <phantrungson17@gmail.com>
|
||||
@ -205,37 +343,63 @@ Takeshi Inoue <inoue.takeshi@gmail.com>
|
||||
Tamotsu Takahashi <ttakah+github@gmail.com>
|
||||
Taras Glek <taras@thegp.com>
|
||||
Tauseef Mohiuddin <35351464+tauseefmohammed2@users.noreply.github.com>
|
||||
Thamster <Thamster@users.noreply.github.com>
|
||||
Thijs Raymakers <thijs@raymakers.nl>
|
||||
Thomas Fitzsimmons <fitzsim@fitzsim.org>
|
||||
Tiago Fassoni <tiagofassoni@users.noreply.github.com>
|
||||
Tienshiao Ma <tienshiao@tienshiao.org>
|
||||
Tim Miller <drasticactions@users.noreply.github.com>
|
||||
Timothy Cronin <40186632+4imothy@users.noreply.github.com>
|
||||
Tobrun <tobrun.van.nuland@gmail.com>
|
||||
Todd <taf2@users.noreply.github.com>
|
||||
Toliver <teejae@gmail.com>
|
||||
Tong Li <31761981+litongjava@users.noreply.github.com>
|
||||
Tony Wasserka <4840017+neobrain@users.noreply.github.com>
|
||||
Topping1 <78745143+Topping1@users.noreply.github.com>
|
||||
Travis Cline <travis.cline@gmail.com>
|
||||
UEXTM.com <84163508+uextm@users.noreply.github.com>
|
||||
UsernamesLame <156965854+UsernamesLame@users.noreply.github.com>
|
||||
Vadim Peretokin <vperetokin@hey.com>
|
||||
Valentin Gosu <1454649+valenting@users.noreply.github.com>
|
||||
Vin Misra <vinith@alum.mit.edu>
|
||||
Vulcan <93451215+trholding@users.noreply.github.com>
|
||||
WhiteOlivierus <36532695+WhiteOlivierus@users.noreply.github.com>
|
||||
William Tambellini <william.tambellini@gmail.com>
|
||||
William Tambellini <wtambellini@sdl.com>
|
||||
Wilson Silva <wilson.dsigns@gmail.com>
|
||||
Xiang (Kevin) Li <kevinli020508@gmail.com>
|
||||
Xiao-Yong Jin <jinxiaoyong@gmail.com>
|
||||
XiaotaoChen <chenxiaotao1234@gmail.com>
|
||||
Xingchen Song(宋星辰) <xingchensong1996@163.com>
|
||||
Xinpeng Dou <81913537+Dou-Git@users.noreply.github.com>
|
||||
Xuan Son Nguyen <thichthat@gmail.com>
|
||||
Yajing Tang <phillis@google.com>
|
||||
Yang Shen <aplshenyang@gmail.com>
|
||||
Yunès <jean.baptiste.yunes@free.fr>
|
||||
Yuri Khrustalev <ykhrustalev@users.noreply.github.com>
|
||||
Yusuf Redžić <48274562+redzic@users.noreply.github.com>
|
||||
ZaBlazzingZephyrus <119159668+blazingzephyr@users.noreply.github.com>
|
||||
Zhenwei Jin <109658203+kylo5aby@users.noreply.github.com>
|
||||
Zhiyuan Li <lizhiyuan@uniartisan.com>
|
||||
Zhiyuan Li <uniartisan2017@gmail.com>
|
||||
Zigfrid Zvezdin <ziggerZZ@gmail.com>
|
||||
Zollner <24618122+Zolliner@users.noreply.github.com>
|
||||
a3sh <38979186+A3shTnT@users.noreply.github.com>
|
||||
ag2s20150909 <19373730+ag2s20150909@users.noreply.github.com>
|
||||
agray3 <agray3@users.noreply.github.com>
|
||||
ai-at-home <149282006+ai-at-home@users.noreply.github.com>
|
||||
aldorof <aldorof@users.noreply.github.com>
|
||||
alonfaraj <alonfaraj@gmail.com>
|
||||
amd-dwang <dong.wang@amd.com>
|
||||
amritahs-ibm <amritahs@linux.vnet.ibm.com>
|
||||
andypayne <apayne@gmail.com>
|
||||
ardfork <134447697+ardfork@users.noreply.github.com>
|
||||
arizhih <40765267+arizhih@users.noreply.github.com>
|
||||
automaticcat <daogiatuank54@gmail.com>
|
||||
bandoti <141645996+bandoti@users.noreply.github.com>
|
||||
be-next <jerome.ramette@gmail.com>
|
||||
bert hubert <bert@hubertnet.nl>
|
||||
billyct <billy_allen@126.com>
|
||||
bmwl <brian.marshall@tolko.com>
|
||||
bobqianic <129547291+bobqianic@users.noreply.github.com>
|
||||
bocytko <bocytko+github@gmail.com>
|
||||
@ -248,7 +412,9 @@ byte-6174 <88070277+byte-6174@users.noreply.github.com>
|
||||
cdosoftei <ciprian.dosoftei@gmail.com>
|
||||
clach04 <Chris.Clark@actian.com>
|
||||
compilade <113953597+compilade@users.noreply.github.com>
|
||||
compilade <git@compilade.net>
|
||||
conradg <conradjgodfrey@gmail.com>
|
||||
crummyh <elijah@crums.us>
|
||||
ddpasa <112642920+ddpasa@users.noreply.github.com>
|
||||
denersc <denerstassun@gmail.com>
|
||||
dscripka <dscripka@users.noreply.github.com>
|
||||
@ -256,28 +422,55 @@ duthils <duthils@duthils.net>
|
||||
ecneladis <ecneladis@users.noreply.github.com>
|
||||
faker <nspyia2002@gmail.com>
|
||||
fitzsim <fitzsim@fitzsim.org>
|
||||
fj-y-saito <85871716+fj-y-saito@users.noreply.github.com>
|
||||
fraxy-v <65565042+fraxy-v@users.noreply.github.com>
|
||||
genevera (she/her) <genevera@users.noreply.github.com>
|
||||
geniusnut <geniusnut@gmail.com>
|
||||
gilbertgong <gilbert.gong@gmail.com>
|
||||
gn64 <yukikaze.jp@gmail.com>
|
||||
goldwaving <77494627+goldwaving@users.noreply.github.com>
|
||||
greeshmay <greeshmay@gmail.com>
|
||||
haopeng <657407891@qq.com>
|
||||
hipudding <huafengchun@gmail.com>
|
||||
hsinhoyeh <yhh92u@gmail.com>
|
||||
hydai <z54981220@gmail.com>
|
||||
iamthad <thadeus.j.fleming@gmail.com>
|
||||
issixx <46835150+issixx@users.noreply.github.com>
|
||||
james wolf <contractorwolf@hotmail.com>
|
||||
jdomke <28772296+jdomke@users.noreply.github.com>
|
||||
jettoblack <jettoblack@gmail.com>
|
||||
jiez <373447296@qq.com>
|
||||
joecryptotoo <80373433+joecryptotoo@users.noreply.github.com>
|
||||
jorismertz <35079666+jorismertz@users.noreply.github.com>
|
||||
junchao-loongson <68935141+junchao-loongson@users.noreply.github.com>
|
||||
junkfood <69683722+JunkFood02@users.noreply.github.com>
|
||||
jwijffels <jwijffels@bnosac.be>
|
||||
k.h.lai <adrian.k.h.lai@outlook.com>
|
||||
kamranjon <kamranjon@gmail.com>
|
||||
katsu560 <katsu560oo-@docomo.ne.jp>
|
||||
kennethge <57784063+kenneth-ge@users.noreply.github.com>
|
||||
keyehzy <msamuel@aluno.puc-rio.br>
|
||||
kunnis <kunnis@users.noreply.github.com>
|
||||
l3utterfly <gc.pthzfoldr@gmail.com>
|
||||
leejet <leejet714@gmail.com>
|
||||
leo-pony <nengjunma@outlook.com>
|
||||
lhez <quic_lih@quicinc.com>
|
||||
litong <31761981+litongjava@users.noreply.github.com>
|
||||
liuwei-git <14815172+liuwei-git@users.noreply.github.com>
|
||||
lnyan <lkwq007@gmail.com>
|
||||
luoyu-intel <yu.luo@intel.com>
|
||||
m.bell <m.bell@techsmith.com>
|
||||
mahorozte <41834471+mahorozte@users.noreply.github.com>
|
||||
mashizora <30516315+mashizora@users.noreply.github.com>
|
||||
matt23654 <matthew.webber@protonmail.com>
|
||||
matteo <matteogeniaccio@yahoo.it>
|
||||
mgrachten <maarten@grachten.eu>
|
||||
mkiol <mkiol@users.noreply.github.com>
|
||||
mky_coder <47767389+mkycoder@users.noreply.github.com>
|
||||
novag <7754358+novag@users.noreply.github.com>
|
||||
pajowu <pajowu@pajowu.de>
|
||||
pengxin99 <pengxin.yuan@intel.com>
|
||||
petterreinholdtsen <pere-github@hungry.com>
|
||||
polarmoon <90010972+polarmoon@users.noreply.github.com>
|
||||
rlapray <lapray.romain@gmail.com>
|
||||
sandrohanea <40202887+sandrohanea@users.noreply.github.com>
|
||||
@ -287,15 +480,31 @@ shikokuchuo <53399081+shikokuchuo@users.noreply.github.com>
|
||||
slaren <slarengh@gmail.com>
|
||||
slashlib <slashlib@users.noreply.github.com>
|
||||
snadampal <87143774+snadampal@users.noreply.github.com>
|
||||
someone13574 <81528246+someone13574@users.noreply.github.com>
|
||||
st-gr <38470677+st-gr@users.noreply.github.com>
|
||||
stduhpf <stephduh@live.fr>
|
||||
stormofice <58337328+stormofice@users.noreply.github.com>
|
||||
texmex76 <40733439+texmex76@users.noreply.github.com>
|
||||
thefinaldegree <thefinaldegree@gmail.com>
|
||||
thewh1teagle <61390950+thewh1teagle@users.noreply.github.com>
|
||||
toboil-features <160222185+toboil-features@users.noreply.github.com>
|
||||
trixirt <trix@redhat.com>
|
||||
ulatekh <ulatekh@yahoo.com>
|
||||
undef <undefdev@gmail.com>
|
||||
uvos <devnull@uvos.xyz>
|
||||
uvos <philipp@uvos.xyz>
|
||||
valVk <valVk@users.noreply.github.com>
|
||||
venkr <venkateshrameshkumar+1@gmail.com>
|
||||
vicalloy <zbirder@gmail.com>
|
||||
wangshuai09 <391746016@qq.com>
|
||||
woachk <24752637+woachk@users.noreply.github.com>
|
||||
xctan <axunlei@gmail.com>
|
||||
xdrudis <xavierdrudis@yahoo.es>
|
||||
yuri@FreeBSD <yuri@FreeBSD>
|
||||
zhangjixiong <code.zjx@gmail.com>
|
||||
zhentaoyu <zhentao.yu@intel.com>
|
||||
zhouwg <6889919+zhouwg@users.noreply.github.com>
|
||||
zhouwg <zhouwg2000@gmail.com>
|
||||
谢乃闻 <sienaiwun@users.noreply.github.com>
|
||||
布客飞龙 <562826179@qq.com>
|
||||
Артём Земляк <azemlyak@smart-consulting.ru>
|
||||
|
@ -1,6 +1,6 @@
|
||||
cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories.
|
||||
project("whisper.cpp" C CXX)
|
||||
project("whisper.cpp" VERSION 1.7.3)
|
||||
project("whisper.cpp" VERSION 1.7.4)
|
||||
include(CheckIncludeFileCXX)
|
||||
|
||||
set(SOVERSION 1)
|
||||
|
2
Makefile
2
Makefile
@ -64,6 +64,6 @@ tiny.en tiny base.en base small.en small medium.en medium large-v1 large-v2 larg
|
||||
echo "[+] Running $@ on $$f ... (run 'ffplay $$f' to listen)" ; \
|
||||
echo "----------------------------------------------" ; \
|
||||
echo "" ; \
|
||||
./build/bin/main -m models/ggml-$@.bin -f $$f ; \
|
||||
./build/bin/whisper-cli -m models/ggml-$@.bin -f $$f ; \
|
||||
echo "" ; \
|
||||
done
|
||||
|
59
README.md
59
README.md
@ -7,14 +7,17 @@
|
||||
[](https://conan.io/center/whisper-cpp)
|
||||
[](https://www.npmjs.com/package/whisper.cpp/)
|
||||
|
||||
Stable: [v1.7.3](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.7.3) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
|
||||
> [!NOTE]
|
||||
> New maintenance roadmap: https://github.com/ggerganov/whisper.cpp/discussions/2788
|
||||
|
||||
Stable: [v1.7.4](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.7.4) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
|
||||
|
||||
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
|
||||
|
||||
- Plain C/C++ implementation without dependencies
|
||||
- Apple Silicon first-class citizen - optimized via ARM NEON, Accelerate framework, Metal and [Core ML](#core-ml-support)
|
||||
- AVX intrinsics support for x86 architectures
|
||||
- VSX intrinsics support for POWER architectures
|
||||
- [VSX intrinsics support for POWER architectures](#power-vsx-intrinsics)
|
||||
- Mixed F16 / F32 precision
|
||||
- [Integer quantization support](#quantization)
|
||||
- Zero memory allocations at runtime
|
||||
@ -136,6 +139,20 @@ make -j large-v3-turbo
|
||||
| medium | 1.5 GiB | ~2.1 GB |
|
||||
| large | 2.9 GiB | ~3.9 GB |
|
||||
|
||||
## POWER VSX Intrinsics
|
||||
|
||||
`whisper.cpp` supports POWER architectures and includes code which
|
||||
significantly speeds operation on Linux running on POWER9/10, making it
|
||||
capable of faster-than-realtime transcription on underclocked Raptor
|
||||
Talos II. Ensure you have a BLAS package installed, and replace the
|
||||
standard cmake setup with:
|
||||
|
||||
```bash
|
||||
# build with GGML_BLAS defined
|
||||
cmake -B build -DGGML_BLAS=1
|
||||
cmake --build build --config Release
|
||||
./build/bin/whisper-cli [ .. etc .. ]
|
||||
|
||||
## Quantization
|
||||
|
||||
`whisper.cpp` supports integer quantization of the Whisper `ggml` models.
|
||||
@ -293,7 +310,7 @@ This can result in significant speedup in encoder performance. Here are the inst
|
||||
The first time run on an OpenVINO device is slow, since the OpenVINO framework will compile the IR (Intermediate Representation) model to a device-specific 'blob'. This device-specific blob will get
|
||||
cached for the next run.
|
||||
|
||||
For more information about the Core ML implementation please refer to PR [#1037](https://github.com/ggerganov/whisper.cpp/pull/1037).
|
||||
For more information about the OpenVINO implementation please refer to PR [#1037](https://github.com/ggerganov/whisper.cpp/pull/1037).
|
||||
|
||||
## NVIDIA GPU support
|
||||
|
||||
@ -360,6 +377,38 @@ Run the inference examples as usual, for example:
|
||||
- If you have trouble with Ascend NPU device, please create a issue with **[CANN]** prefix/tag.
|
||||
- If you run successfully with your Ascend NPU device, please help update the table `Verified devices`.
|
||||
|
||||
## Docker
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Docker must be installed and running on your system.
|
||||
- Create a folder to store big models & intermediate files (ex. /whisper/models)
|
||||
|
||||
### Images
|
||||
|
||||
We have two Docker images available for this project:
|
||||
|
||||
1. `ghcr.io/ggerganov/whisper.cpp:main`: This image includes the main executable file as well as `curl` and `ffmpeg`. (platforms: `linux/amd64`, `linux/arm64`)
|
||||
2. `ghcr.io/ggerganov/whisper.cpp:main-cuda`: Same as `main` but compiled with CUDA support. (platforms: `linux/amd64`)
|
||||
|
||||
### Usage
|
||||
|
||||
```shell
|
||||
# download model and persist it in a local folder
|
||||
docker run -it --rm \
|
||||
-v path/to/models:/models \
|
||||
whisper.cpp:main "./models/download-ggml-model.sh base /models"
|
||||
# transcribe an audio file
|
||||
docker run -it --rm \
|
||||
-v path/to/models:/models \
|
||||
-v path/to/audios:/audios \
|
||||
whisper.cpp:main "./main -m /models/ggml-base.bin -f /audios/jfk.wav"
|
||||
# transcribe an audio file in samples folder
|
||||
docker run -it --rm \
|
||||
-v path/to/models:/models \
|
||||
whisper.cpp:main "./main -m /models/ggml-base.bin -f ./samples/jfk.wav"
|
||||
```
|
||||
|
||||
## Installing with Conan
|
||||
|
||||
You can install pre-built binaries for whisper.cpp or build it from source using [Conan](https://conan.io/). Use the following command:
|
||||
@ -381,9 +430,9 @@ The [stream](examples/stream) tool samples the audio every half a second and run
|
||||
More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10).
|
||||
|
||||
```bash
|
||||
cmake -B build
|
||||
cmake -B build -DWHISPER_SDL2=ON
|
||||
cmake --build build --config Release
|
||||
./build/bin/stream -m ./models/ggml-base.en.bin -t 8 --step 500 --length 5000
|
||||
./build/bin/whisper-stream -m ./models/ggml-base.en.bin -t 8 --step 500 --length 5000
|
||||
```
|
||||
|
||||
https://user-images.githubusercontent.com/1991296/194935793-76afede7-cfa8-48d8-a80f-28ba83be7d09.mp4
|
||||
|
@ -181,11 +181,11 @@ public class WhisperFullParams extends Structure {
|
||||
}
|
||||
|
||||
/** Flag to suppress non-speech tokens. */
|
||||
public CBool suppress_non_speech_tokens;
|
||||
public CBool suppress_nst;
|
||||
|
||||
/** Flag to suppress non-speech tokens. */
|
||||
public void suppressNonSpeechTokens(boolean enable) {
|
||||
suppress_non_speech_tokens = enable ? CBool.TRUE : CBool.FALSE;
|
||||
suppress_nst = enable ? CBool.TRUE : CBool.FALSE;
|
||||
}
|
||||
|
||||
/** Initial decoding temperature. */
|
||||
@ -315,7 +315,7 @@ public class WhisperFullParams extends Structure {
|
||||
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
|
||||
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "audio_ctx",
|
||||
"tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
|
||||
"suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",
|
||||
"suppress_blank", "suppress_nst", "temperature", "max_initial_ts", "length_penalty",
|
||||
"temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
|
||||
"new_segment_callback", "new_segment_callback_user_data",
|
||||
"progress_callback", "progress_callback_user_data",
|
||||
|
@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "whisper.cpp",
|
||||
"version": "1.7.3",
|
||||
"version": "1.7.4",
|
||||
"description": "Whisper speech recognition",
|
||||
"main": "whisper.js",
|
||||
"scripts": {
|
||||
|
4
bindings/ruby/.gitignore
vendored
4
bindings/ruby/.gitignore
vendored
@ -1,5 +1,3 @@
|
||||
LICENSE
|
||||
pkg/
|
||||
lib/whisper.so
|
||||
lib/whisper.bundle
|
||||
lib/whisper.dll
|
||||
lib/whisper.*
|
||||
|
@ -24,14 +24,15 @@ require "whisper"
|
||||
|
||||
whisper = Whisper::Context.new("base")
|
||||
|
||||
params = Whisper::Params.new
|
||||
params.language = "en"
|
||||
params.offset = 10_000
|
||||
params.duration = 60_000
|
||||
params.max_text_tokens = 300
|
||||
params.translate = true
|
||||
params.print_timestamps = false
|
||||
params.initial_prompt = "Initial prompt here."
|
||||
params = Whisper::Params.new(
|
||||
language: "en",
|
||||
offset: 10_000,
|
||||
duration: 60_000,
|
||||
max_text_tokens: 300,
|
||||
translate: true,
|
||||
print_timestamps: false,
|
||||
initial_prompt: "Initial prompt here."
|
||||
)
|
||||
|
||||
whisper.transcribe("path/to/audio.wav", params) do |whole_text|
|
||||
puts whole_text
|
||||
@ -60,10 +61,10 @@ You also can use shorthand for pre-converted models:
|
||||
whisper = Whisper::Context.new("base.en")
|
||||
```
|
||||
|
||||
You can see the list of prepared model names by `Whisper::Model.preconverted_models.keys`:
|
||||
You can see the list of prepared model names by `Whisper::Model.pre_converted_models.keys`:
|
||||
|
||||
```ruby
|
||||
puts Whisper::Model.preconverted_model_names
|
||||
puts Whisper::Model.pre_converted_models.keys
|
||||
# tiny
|
||||
# tiny.en
|
||||
# tiny-q5_1
|
||||
@ -87,8 +88,9 @@ whisper = Whisper::Context.new("path/to/your/model.bin")
|
||||
Or, you can download model files:
|
||||
|
||||
```ruby
|
||||
model_uri = Whisper::Model::URI.new("http://example.net/uri/of/your/model.bin")
|
||||
whisper = Whisper::Context.new(model_uri)
|
||||
whisper = Whisper::Context.new("https://example.net/uri/of/your/model.bin")
|
||||
# Or
|
||||
whisper = Whisper::Context.new(URI("https://example.net/uri/of/your/model.bin"))
|
||||
```
|
||||
|
||||
See [models][] page for details.
|
||||
@ -112,18 +114,18 @@ def format_time(time_ms)
|
||||
"%02d:%02d:%02d.%03d" % [hour, min, sec, decimal_part]
|
||||
end
|
||||
|
||||
whisper.transcribe("path/to/audio.wav", params)
|
||||
|
||||
whisper.each_segment.with_index do |segment, index|
|
||||
line = "[%{nth}: %{st} --> %{ed}] %{text}" % {
|
||||
nth: index + 1,
|
||||
st: format_time(segment.start_time),
|
||||
ed: format_time(segment.end_time),
|
||||
text: segment.text
|
||||
}
|
||||
line << " (speaker turned)" if segment.speaker_next_turn?
|
||||
puts line
|
||||
end
|
||||
whisper
|
||||
.transcribe("path/to/audio.wav", params)
|
||||
.each_segment.with_index do |segment, index|
|
||||
line = "[%{nth}: %{st} --> %{ed}] %{text}" % {
|
||||
nth: index + 1,
|
||||
st: format_time(segment.start_time),
|
||||
ed: format_time(segment.end_time),
|
||||
text: segment.text
|
||||
}
|
||||
line << " (speaker turned)" if segment.speaker_next_turn?
|
||||
puts line
|
||||
end
|
||||
|
||||
```
|
||||
|
||||
@ -214,13 +216,25 @@ reader = WaveFile::Reader.new("path/to/audio.wav", WaveFile::Format.new(:mono, :
|
||||
samples = reader.enum_for(:each_buffer).map(&:samples).flatten
|
||||
|
||||
whisper = Whisper::Context.new("base")
|
||||
whisper.full(Whisper::Params.new, samples)
|
||||
whisper.each_segment do |segment|
|
||||
puts segment.text
|
||||
end
|
||||
whisper
|
||||
.full(Whisper::Params.new, samples)
|
||||
.each_segment do |segment|
|
||||
puts segment.text
|
||||
end
|
||||
```
|
||||
|
||||
The second argument `samples` may be an array, an object with `length` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy.
|
||||
The second argument `samples` may be an array, an object with `length` and `each` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy.
|
||||
|
||||
Development
|
||||
-----------
|
||||
|
||||
% git clone https://github.com/ggerganov/whisper.cpp.git
|
||||
% cd whisper.cpp/bindings/ruby
|
||||
% rake test
|
||||
|
||||
First call of `rake test` builds an extension and downloads a model for testing. After that, you add tests in `tests` directory and modify `ext/ruby_whisper.cpp`.
|
||||
|
||||
If something seems wrong on build, running `rake clean` solves some cases.
|
||||
|
||||
License
|
||||
-------
|
||||
|
@ -18,9 +18,11 @@ EXTSOURCES.each do |src|
|
||||
end
|
||||
|
||||
CLEAN.include SOURCES
|
||||
CLEAN.include FileList["ext/*.o", "ext/*.metal", "ext/whisper.{so,bundle,dll}"]
|
||||
CLEAN.include FileList["ext/**/*.o", "ext/**/*.metal", "ext/**/*.tmp", "ext/whisper.{so,bundle,dll}"]
|
||||
|
||||
task build: ["ext/Makefile", "ext/ruby_whisper.h", "ext/ruby_whisper.cpp", "whispercpp.gemspec"]
|
||||
SRC = FileList["ext/*.{c,cpp,h}"]
|
||||
|
||||
task build: SOURCES
|
||||
|
||||
directory "pkg"
|
||||
CLOBBER.include "pkg"
|
||||
@ -29,14 +31,14 @@ LIB_NAME = "whisper".ext(RbConfig::CONFIG["DLEXT"])
|
||||
SO_FILE = File.join("ext", LIB_NAME)
|
||||
LIB_FILE = File.join("lib", LIB_NAME)
|
||||
|
||||
file "ext/Makefile" => ["ext/extconf.rb", "ext/ruby_whisper.h", "ext/ruby_whisper.cpp"] + SOURCES do |t|
|
||||
Dir.chdir "ext" do
|
||||
file "ext/Makefile" => SRC + ["ext/extconf.rb"] + SOURCES do |t|
|
||||
chdir "ext" do
|
||||
ruby "extconf.rb"
|
||||
end
|
||||
end
|
||||
|
||||
file SO_FILE => "ext/Makefile" do |t|
|
||||
Dir.chdir "ext" do
|
||||
chdir "ext" do
|
||||
sh "make"
|
||||
end
|
||||
end
|
||||
@ -54,7 +56,7 @@ end
|
||||
|
||||
TEST_MEMORY_VIEW = "tests/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}"
|
||||
file TEST_MEMORY_VIEW => "tests/jfk_reader/jfk_reader.c" do |t|
|
||||
Dir.chdir "tests/jfk_reader" do
|
||||
chdir "tests/jfk_reader" do
|
||||
ruby "extconf.rb"
|
||||
sh "make"
|
||||
end
|
||||
|
12
bindings/ruby/ext/.gitignore
vendored
12
bindings/ruby/ext/.gitignore
vendored
@ -4,10 +4,8 @@ whisper.bundle
|
||||
whisper.dll
|
||||
scripts/get-flags.mk
|
||||
*.o
|
||||
*.c
|
||||
*.cpp
|
||||
*.h
|
||||
*.m
|
||||
*.metal
|
||||
!ruby_whisper.cpp
|
||||
!ruby_whisper.h
|
||||
/*/**/*.c
|
||||
/*/**/*.cpp
|
||||
/*/**/*.h
|
||||
/*/**/*.m
|
||||
/*/**/*.metal
|
||||
|
@ -174,7 +174,14 @@ $OBJ_WHISPER <<
|
||||
'src/whisper.o'
|
||||
|
||||
$objs = $OBJ_GGML + $OBJ_WHISPER + $OBJ_COMMON + $OBJ_SDL
|
||||
$objs << "ruby_whisper.o"
|
||||
$objs <<
|
||||
"ruby_whisper.o" <<
|
||||
"ruby_whisper_context.o" <<
|
||||
"ruby_whisper_transcribe.o" <<
|
||||
"ruby_whisper_params.o" <<
|
||||
"ruby_whisper_error.o" <<
|
||||
"ruby_whisper_segment.o" <<
|
||||
"ruby_whisper_model.o"
|
||||
|
||||
$CPPFLAGS = "#{$MK_CPPFLAGS} #{$CPPFLAGS}"
|
||||
$CFLAGS = "#{$CPPFLAGS} #{$MK_CFLAGS} #{$GF_CFLAGS} #{$CFLAGS}"
|
||||
|
164
bindings/ruby/ext/ruby_whisper.c
Normal file
164
bindings/ruby/ext/ruby_whisper.c
Normal file
@ -0,0 +1,164 @@
|
||||
#include <ruby.h>
|
||||
#include <ruby/memory_view.h>
|
||||
#include "ruby_whisper.h"
|
||||
|
||||
VALUE mWhisper;
|
||||
VALUE cContext;
|
||||
VALUE cParams;
|
||||
VALUE eError;
|
||||
|
||||
VALUE cSegment;
|
||||
VALUE cModel;
|
||||
|
||||
ID id_to_s;
|
||||
ID id_call;
|
||||
ID id___method__;
|
||||
ID id_to_enum;
|
||||
ID id_length;
|
||||
ID id_next;
|
||||
ID id_new;
|
||||
ID id_to_path;
|
||||
ID id_URI;
|
||||
ID id_pre_converted_models;
|
||||
|
||||
static bool is_log_callback_finalized = false;
|
||||
|
||||
// High level API
|
||||
extern VALUE ruby_whisper_segment_allocate(VALUE klass);
|
||||
|
||||
extern void init_ruby_whisper_context(VALUE *mWhisper);
|
||||
extern void init_ruby_whisper_params(VALUE *mWhisper);
|
||||
extern void init_ruby_whisper_error(VALUE *mWhisper);
|
||||
extern void init_ruby_whisper_segment(VALUE *mWhisper, VALUE *cSegment);
|
||||
extern void init_ruby_whisper_model(VALUE *mWhisper);
|
||||
extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context);
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* lang_max_id -> Integer
|
||||
*/
|
||||
static VALUE ruby_whisper_s_lang_max_id(VALUE self) {
|
||||
return INT2NUM(whisper_lang_max_id());
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* lang_id(lang_name) -> Integer
|
||||
*/
|
||||
static VALUE ruby_whisper_s_lang_id(VALUE self, VALUE lang) {
|
||||
const char * lang_str = StringValueCStr(lang);
|
||||
const int id = whisper_lang_id(lang_str);
|
||||
if (-1 == id) {
|
||||
rb_raise(rb_eArgError, "language not found: %s", lang_str);
|
||||
}
|
||||
return INT2NUM(id);
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* lang_str(lang_id) -> String
|
||||
*/
|
||||
static VALUE ruby_whisper_s_lang_str(VALUE self, VALUE id) {
|
||||
const int lang_id = NUM2INT(id);
|
||||
const char * str = whisper_lang_str(lang_id);
|
||||
if (NULL == str) {
|
||||
rb_raise(rb_eIndexError, "id %d outside of language id", lang_id);
|
||||
}
|
||||
return rb_str_new2(str);
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* lang_str(lang_id) -> String
|
||||
*/
|
||||
static VALUE ruby_whisper_s_lang_str_full(VALUE self, VALUE id) {
|
||||
const int lang_id = NUM2INT(id);
|
||||
const char * str_full = whisper_lang_str_full(lang_id);
|
||||
if (NULL == str_full) {
|
||||
rb_raise(rb_eIndexError, "id %d outside of language id", lang_id);
|
||||
}
|
||||
return rb_str_new2(str_full);
|
||||
}
|
||||
|
||||
static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) {
|
||||
is_log_callback_finalized = true;
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
static void
|
||||
ruby_whisper_log_callback(enum ggml_log_level level, const char * buffer, void * user_data) {
|
||||
if (is_log_callback_finalized) {
|
||||
return;
|
||||
}
|
||||
VALUE log_callback = rb_iv_get(mWhisper, "log_callback");
|
||||
VALUE udata = rb_iv_get(mWhisper, "user_data");
|
||||
rb_funcall(log_callback, id_call, 3, INT2NUM(level), rb_str_new2(buffer), udata);
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* log_set ->(level, buffer, user_data) { ... }, user_data -> nil
|
||||
*/
|
||||
static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) {
|
||||
VALUE old_callback = rb_iv_get(self, "log_callback");
|
||||
if (!NIL_P(old_callback)) {
|
||||
rb_undefine_finalizer(old_callback);
|
||||
}
|
||||
|
||||
rb_iv_set(self, "log_callback", log_callback);
|
||||
rb_iv_set(self, "user_data", user_data);
|
||||
|
||||
VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback"));
|
||||
rb_define_finalizer(log_callback, finalize_log_callback);
|
||||
|
||||
whisper_log_set(ruby_whisper_log_callback, NULL);
|
||||
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
static void rb_whisper_model_mark(ruby_whisper_model *rwm) {
|
||||
rb_gc_mark(rwm->context);
|
||||
}
|
||||
|
||||
static VALUE ruby_whisper_model_allocate(VALUE klass) {
|
||||
ruby_whisper_model *rwm;
|
||||
rwm = ALLOC(ruby_whisper_model);
|
||||
return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm);
|
||||
}
|
||||
|
||||
void Init_whisper() {
|
||||
id_to_s = rb_intern("to_s");
|
||||
id_call = rb_intern("call");
|
||||
id___method__ = rb_intern("__method__");
|
||||
id_to_enum = rb_intern("to_enum");
|
||||
id_length = rb_intern("length");
|
||||
id_next = rb_intern("next");
|
||||
id_new = rb_intern("new");
|
||||
id_to_path = rb_intern("to_path");
|
||||
id_URI = rb_intern("URI");
|
||||
id_pre_converted_models = rb_intern("pre_converted_models");
|
||||
|
||||
mWhisper = rb_define_module("Whisper");
|
||||
|
||||
rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE));
|
||||
rb_define_const(mWhisper, "LOG_LEVEL_INFO", INT2NUM(GGML_LOG_LEVEL_INFO));
|
||||
rb_define_const(mWhisper, "LOG_LEVEL_WARN", INT2NUM(GGML_LOG_LEVEL_WARN));
|
||||
rb_define_const(mWhisper, "LOG_LEVEL_ERROR", INT2NUM(GGML_LOG_LEVEL_ERROR));
|
||||
rb_define_const(mWhisper, "LOG_LEVEL_DEBUG", INT2NUM(GGML_LOG_LEVEL_DEBUG));
|
||||
rb_define_const(mWhisper, "LOG_LEVEL_CONT", INT2NUM(GGML_LOG_LEVEL_CONT));
|
||||
|
||||
rb_define_singleton_method(mWhisper, "lang_max_id", ruby_whisper_s_lang_max_id, 0);
|
||||
rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1);
|
||||
rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1);
|
||||
rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1);
|
||||
rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2);
|
||||
rb_define_private_method(rb_singleton_class(mWhisper), "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1);
|
||||
|
||||
init_ruby_whisper_context(&mWhisper);
|
||||
init_ruby_whisper_params(&mWhisper);
|
||||
init_ruby_whisper_error(&mWhisper);
|
||||
init_ruby_whisper_segment(&mWhisper, &cContext);
|
||||
init_ruby_whisper_model(&mWhisper);
|
||||
|
||||
rb_require("whisper/model/uri");
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -1,5 +1,5 @@
|
||||
#ifndef __RUBY_WHISPER_H
|
||||
#define __RUBY_WHISPER_H
|
||||
#ifndef RUBY_WHISPER_H
|
||||
#define RUBY_WHISPER_H
|
||||
|
||||
#include "whisper.h"
|
||||
|
||||
@ -22,4 +22,13 @@ typedef struct {
|
||||
ruby_whisper_callback_container *abort_callback_container;
|
||||
} ruby_whisper_params;
|
||||
|
||||
typedef struct {
|
||||
VALUE context;
|
||||
int index;
|
||||
} ruby_whisper_segment;
|
||||
|
||||
typedef struct {
|
||||
VALUE context;
|
||||
} ruby_whisper_model;
|
||||
|
||||
#endif
|
||||
|
613
bindings/ruby/ext/ruby_whisper_context.c
Normal file
613
bindings/ruby/ext/ruby_whisper_context.c
Normal file
@ -0,0 +1,613 @@
|
||||
#include <ruby.h>
|
||||
#include <ruby/memory_view.h>
|
||||
#include "ruby_whisper.h"
|
||||
|
||||
extern ID id_to_s;
|
||||
extern ID id___method__;
|
||||
extern ID id_to_enum;
|
||||
extern ID id_length;
|
||||
extern ID id_next;
|
||||
extern ID id_new;
|
||||
extern ID id_to_path;
|
||||
extern ID id_URI;
|
||||
extern ID id_pre_converted_models;
|
||||
|
||||
extern VALUE cContext;
|
||||
extern VALUE eError;
|
||||
extern VALUE cModel;
|
||||
|
||||
extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self);
|
||||
extern VALUE rb_whisper_model_initialize(VALUE context);
|
||||
extern VALUE rb_whisper_segment_initialize(VALUE context, int index);
|
||||
extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context);
|
||||
|
||||
static void
|
||||
ruby_whisper_free(ruby_whisper *rw)
|
||||
{
|
||||
if (rw->context) {
|
||||
whisper_free(rw->context);
|
||||
rw->context = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
rb_whisper_mark(ruby_whisper *rw)
|
||||
{
|
||||
// call rb_gc_mark on any ruby references in rw
|
||||
}
|
||||
|
||||
void
|
||||
rb_whisper_free(ruby_whisper *rw)
|
||||
{
|
||||
ruby_whisper_free(rw);
|
||||
free(rw);
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_allocate(VALUE klass)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
rw = ALLOC(ruby_whisper);
|
||||
rw->context = NULL;
|
||||
return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw);
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* new("base.en") -> Whisper::Context
|
||||
* new("path/to/model.bin") -> Whisper::Context
|
||||
* new(Whisper::Model::URI.new("https://example.net/uri/of/model.bin")) -> Whisper::Context
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_initialize(int argc, VALUE *argv, VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
VALUE whisper_model_file_path;
|
||||
|
||||
// TODO: we can support init from buffer here too maybe another ruby object to expose
|
||||
rb_scan_args(argc, argv, "01", &whisper_model_file_path);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
|
||||
VALUE pre_converted_models = rb_funcall(cModel, id_pre_converted_models, 0);
|
||||
VALUE pre_converted_model = rb_hash_aref(pre_converted_models, whisper_model_file_path);
|
||||
if (!NIL_P(pre_converted_model)) {
|
||||
whisper_model_file_path = pre_converted_model;
|
||||
}
|
||||
if (TYPE(whisper_model_file_path) == T_STRING) {
|
||||
const char * whisper_model_file_path_str = StringValueCStr(whisper_model_file_path);
|
||||
if (strncmp("http://", whisper_model_file_path_str, 7) == 0 || strncmp("https://", whisper_model_file_path_str, 8) == 0) {
|
||||
VALUE uri_class = rb_const_get(cModel, id_URI);
|
||||
whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class);
|
||||
}
|
||||
}
|
||||
if (rb_obj_is_kind_of(whisper_model_file_path, rb_path2class("URI::HTTP"))) {
|
||||
VALUE uri_class = rb_const_get(cModel, id_URI);
|
||||
whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class);
|
||||
}
|
||||
if (rb_respond_to(whisper_model_file_path, id_to_path)) {
|
||||
whisper_model_file_path = rb_funcall(whisper_model_file_path, id_to_path, 0);
|
||||
}
|
||||
if (!rb_respond_to(whisper_model_file_path, id_to_s)) {
|
||||
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
|
||||
}
|
||||
rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params());
|
||||
if (rw->context == NULL) {
|
||||
rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context");
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* model_n_vocab -> Integer
|
||||
*/
|
||||
VALUE ruby_whisper_model_n_vocab(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_vocab(rw->context));
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* model_n_audio_ctx -> Integer
|
||||
*/
|
||||
VALUE ruby_whisper_model_n_audio_ctx(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_audio_ctx(rw->context));
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* model_n_audio_state -> Integer
|
||||
*/
|
||||
VALUE ruby_whisper_model_n_audio_state(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_audio_state(rw->context));
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* model_n_audio_head -> Integer
|
||||
*/
|
||||
VALUE ruby_whisper_model_n_audio_head(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_audio_head(rw->context));
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* model_n_audio_layer -> Integer
|
||||
*/
|
||||
VALUE ruby_whisper_model_n_audio_layer(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_audio_layer(rw->context));
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* model_n_text_ctx -> Integer
|
||||
*/
|
||||
VALUE ruby_whisper_model_n_text_ctx(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_text_ctx(rw->context));
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* model_n_text_state -> Integer
|
||||
*/
|
||||
VALUE ruby_whisper_model_n_text_state(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_text_state(rw->context));
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* model_n_text_head -> Integer
|
||||
*/
|
||||
VALUE ruby_whisper_model_n_text_head(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_text_head(rw->context));
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* model_n_text_layer -> Integer
|
||||
*/
|
||||
VALUE ruby_whisper_model_n_text_layer(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_text_layer(rw->context));
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* model_n_mels -> Integer
|
||||
*/
|
||||
VALUE ruby_whisper_model_n_mels(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_mels(rw->context));
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* model_ftype -> Integer
|
||||
*/
|
||||
VALUE ruby_whisper_model_ftype(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_ftype(rw->context));
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* model_type -> String
|
||||
*/
|
||||
VALUE ruby_whisper_model_type(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return rb_str_new2(whisper_model_type_readable(rw->context));
|
||||
}
|
||||
|
||||
/*
|
||||
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
||||
* Not thread safe for same context
|
||||
* Uses the specified decoding strategy to obtain the text.
|
||||
*
|
||||
* call-seq:
|
||||
* full(params, samples, n_samples) -> nil
|
||||
* full(params, samples) -> nil
|
||||
*
|
||||
* The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data.
|
||||
*/
|
||||
VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self)
|
||||
{
|
||||
if (argc < 2 || argc > 3) {
|
||||
rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
|
||||
}
|
||||
|
||||
ruby_whisper *rw;
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
VALUE params = argv[0];
|
||||
Data_Get_Struct(params, ruby_whisper_params, rwp);
|
||||
VALUE samples = argv[1];
|
||||
int n_samples;
|
||||
rb_memory_view_t view;
|
||||
const bool memory_view_available_p = rb_memory_view_available_p(samples);
|
||||
if (argc == 3) {
|
||||
n_samples = NUM2INT(argv[2]);
|
||||
if (TYPE(samples) == T_ARRAY) {
|
||||
if (RARRAY_LEN(samples) < n_samples) {
|
||||
rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
|
||||
}
|
||||
}
|
||||
// Should check when samples.respond_to?(:length)?
|
||||
} else {
|
||||
if (TYPE(samples) == T_ARRAY) {
|
||||
n_samples = RARRAY_LEN(samples);
|
||||
} else if (memory_view_available_p) {
|
||||
if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
|
||||
view.obj = Qnil;
|
||||
rb_raise(rb_eArgError, "unable to get a memory view");
|
||||
}
|
||||
n_samples = view.byte_size / view.item_size;
|
||||
} else if (rb_respond_to(samples, id_length)) {
|
||||
n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
|
||||
} else {
|
||||
rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
|
||||
}
|
||||
}
|
||||
float * c_samples = (float *)malloc(n_samples * sizeof(float));
|
||||
if (memory_view_available_p) {
|
||||
c_samples = (float *)view.data;
|
||||
} else {
|
||||
if (TYPE(samples) == T_ARRAY) {
|
||||
for (int i = 0; i < n_samples; i++) {
|
||||
c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
|
||||
}
|
||||
} else {
|
||||
// TODO: use rb_block_call
|
||||
VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
|
||||
for (int i = 0; i < n_samples; i++) {
|
||||
// TODO: check if iter is exhausted and raise ArgumentError appropriately
|
||||
VALUE sample = rb_funcall(iter, id_next, 0);
|
||||
c_samples[i] = RFLOAT_VALUE(sample);
|
||||
}
|
||||
}
|
||||
}
|
||||
register_callbacks(rwp, &self);
|
||||
const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples);
|
||||
if (0 == result) {
|
||||
return self;
|
||||
} else {
|
||||
rb_exc_raise(rb_funcall(eError, id_new, 1, result));
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
|
||||
* Result is stored in the default state of the context
|
||||
* Not thread safe if executed in parallel on the same context.
|
||||
* It seems this approach can offer some speedup in some cases.
|
||||
* However, the transcription accuracy can be worse at the beginning and end of each chunk.
|
||||
*
|
||||
* call-seq:
|
||||
* full_parallel(params, samples) -> nil
|
||||
* full_parallel(params, samples, n_samples) -> nil
|
||||
* full_parallel(params, samples, n_samples, n_processors) -> nil
|
||||
* full_parallel(params, samples, nil, n_processors) -> nil
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self)
|
||||
{
|
||||
if (argc < 2 || argc > 4) {
|
||||
rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
|
||||
}
|
||||
|
||||
ruby_whisper *rw;
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
VALUE params = argv[0];
|
||||
Data_Get_Struct(params, ruby_whisper_params, rwp);
|
||||
VALUE samples = argv[1];
|
||||
int n_samples;
|
||||
int n_processors;
|
||||
rb_memory_view_t view;
|
||||
const bool memory_view_available_p = rb_memory_view_available_p(samples);
|
||||
switch (argc) {
|
||||
case 2:
|
||||
n_processors = 1;
|
||||
break;
|
||||
case 3:
|
||||
n_processors = 1;
|
||||
break;
|
||||
case 4:
|
||||
n_processors = NUM2INT(argv[3]);
|
||||
break;
|
||||
}
|
||||
if (argc >= 3 && !NIL_P(argv[2])) {
|
||||
n_samples = NUM2INT(argv[2]);
|
||||
if (TYPE(samples) == T_ARRAY) {
|
||||
if (RARRAY_LEN(samples) < n_samples) {
|
||||
rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
|
||||
}
|
||||
}
|
||||
// Should check when samples.respond_to?(:length)?
|
||||
} else if (memory_view_available_p) {
|
||||
if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
|
||||
view.obj = Qnil;
|
||||
rb_raise(rb_eArgError, "unable to get a memory view");
|
||||
}
|
||||
n_samples = view.byte_size / view.item_size;
|
||||
} else {
|
||||
if (TYPE(samples) == T_ARRAY) {
|
||||
n_samples = RARRAY_LEN(samples);
|
||||
} else if (rb_respond_to(samples, id_length)) {
|
||||
n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
|
||||
} else {
|
||||
rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
|
||||
}
|
||||
}
|
||||
float * c_samples = (float *)malloc(n_samples * sizeof(float));
|
||||
if (memory_view_available_p) {
|
||||
c_samples = (float *)view.data;
|
||||
} else {
|
||||
if (TYPE(samples) == T_ARRAY) {
|
||||
for (int i = 0; i < n_samples; i++) {
|
||||
c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
|
||||
}
|
||||
} else {
|
||||
// FIXME: use rb_block_call
|
||||
VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
|
||||
for (int i = 0; i < n_samples; i++) {
|
||||
// TODO: check if iter is exhausted and raise ArgumentError
|
||||
VALUE sample = rb_funcall(iter, id_next, 0);
|
||||
c_samples[i] = RFLOAT_VALUE(sample);
|
||||
}
|
||||
}
|
||||
}
|
||||
register_callbacks(rwp, &self);
|
||||
const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors);
|
||||
if (0 == result) {
|
||||
return self;
|
||||
} else {
|
||||
rb_exc_raise(rb_funcall(eError, id_new, 1, result));
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Number of segments.
|
||||
*
|
||||
* call-seq:
|
||||
* full_n_segments -> Integer
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_full_n_segments(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_full_n_segments(rw->context));
|
||||
}
|
||||
|
||||
/*
|
||||
* Language ID, which can be converted to string by Whisper.lang_str and Whisper.lang_str_full.
|
||||
*
|
||||
* call-seq:
|
||||
* full_lang_id -> Integer
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_full_lang_id(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_full_lang_id(rw->context));
|
||||
}
|
||||
|
||||
static int ruby_whisper_full_check_segment_index(const ruby_whisper * rw, const VALUE i_segment)
|
||||
{
|
||||
const int c_i_segment = NUM2INT(i_segment);
|
||||
if (c_i_segment < 0 || c_i_segment >= whisper_full_n_segments(rw->context)) {
|
||||
rb_raise(rb_eIndexError, "segment index %d out of range", c_i_segment);
|
||||
}
|
||||
return c_i_segment;
|
||||
}
|
||||
|
||||
/*
|
||||
* Start time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
|
||||
*
|
||||
* full_get_segment_t0(3) # => 1668 (16680 ms)
|
||||
*
|
||||
* call-seq:
|
||||
* full_get_segment_t0(segment_index) -> Integer
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_full_get_segment_t0(VALUE self, VALUE i_segment)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
|
||||
const int64_t t0 = whisper_full_get_segment_t0(rw->context, c_i_segment);
|
||||
return INT2NUM(t0);
|
||||
}
|
||||
|
||||
/*
|
||||
* End time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
|
||||
*
|
||||
* full_get_segment_t1(3) # => 1668 (16680 ms)
|
||||
*
|
||||
* call-seq:
|
||||
* full_get_segment_t1(segment_index) -> Integer
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_full_get_segment_t1(VALUE self, VALUE i_segment)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(rw->context, c_i_segment);
|
||||
return INT2NUM(t1);
|
||||
}
|
||||
|
||||
/*
|
||||
* Whether the next segment indexed by +segment_index+ is predicated as a speaker turn.
|
||||
*
|
||||
* full_get_segment_speacker_turn_next(3) # => true
|
||||
*
|
||||
* call-seq:
|
||||
* full_get_segment_speacker_turn_next(segment_index) -> bool
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_full_get_segment_speaker_turn_next(VALUE self, VALUE i_segment)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
|
||||
const bool speaker_turn_next = whisper_full_get_segment_speaker_turn_next(rw->context, c_i_segment);
|
||||
return speaker_turn_next ? Qtrue : Qfalse;
|
||||
}
|
||||
|
||||
/*
|
||||
* Text of a segment indexed by +segment_index+.
|
||||
*
|
||||
* full_get_segment_text(3) # => "ask not what your country can do for you, ..."
|
||||
*
|
||||
* call-seq:
|
||||
* full_get_segment_text(segment_index) -> String
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
|
||||
const char * text = whisper_full_get_segment_text(rw->context, c_i_segment);
|
||||
return rb_str_new2(text);
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* full_get_segment_no_speech_prob(segment_index) -> Float
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
|
||||
const float no_speech_prob = whisper_full_get_segment_no_speech_prob(rw->context, c_i_segment);
|
||||
return DBL2NUM(no_speech_prob);
|
||||
}
|
||||
|
||||
// High level API
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_full_get_segment(VALUE self, VALUE i_segment)
|
||||
{
|
||||
return rb_whisper_segment_initialize(self, NUM2INT(i_segment));
|
||||
}
|
||||
|
||||
/*
|
||||
* Yields each Whisper::Segment:
|
||||
*
|
||||
* whisper.transcribe("path/to/audio.wav", params)
|
||||
* whisper.each_segment do |segment|
|
||||
* puts segment.text
|
||||
* end
|
||||
*
|
||||
* Returns an Enumerator if no block given:
|
||||
*
|
||||
* whisper.transcribe("path/to/audio.wav", params)
|
||||
* enum = whisper.each_segment
|
||||
* enum.to_a # => [#<Whisper::Segment>, ...]
|
||||
*
|
||||
* call-seq:
|
||||
* each_segment {|segment| ... }
|
||||
* each_segment -> Enumerator
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_each_segment(VALUE self)
|
||||
{
|
||||
if (!rb_block_given_p()) {
|
||||
const VALUE method_name = rb_funcall(self, id___method__, 0);
|
||||
return rb_funcall(self, id_to_enum, 1, method_name);
|
||||
}
|
||||
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
|
||||
const int n_segments = whisper_full_n_segments(rw->context);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
rb_yield(rb_whisper_segment_initialize(self, i));
|
||||
}
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* model -> Whisper::Model
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_get_model(VALUE self)
|
||||
{
|
||||
return rb_whisper_model_initialize(self);
|
||||
}
|
||||
|
||||
void
|
||||
init_ruby_whisper_context(VALUE *mWhisper)
|
||||
{
|
||||
cContext = rb_define_class_under(*mWhisper, "Context", rb_cObject);
|
||||
|
||||
rb_define_alloc_func(cContext, ruby_whisper_allocate);
|
||||
rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1);
|
||||
|
||||
rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1);
|
||||
rb_define_method(cContext, "model_n_vocab", ruby_whisper_model_n_vocab, 0);
|
||||
rb_define_method(cContext, "model_n_audio_ctx", ruby_whisper_model_n_audio_ctx, 0);
|
||||
rb_define_method(cContext, "model_n_audio_state", ruby_whisper_model_n_audio_state, 0);
|
||||
rb_define_method(cContext, "model_n_audio_head", ruby_whisper_model_n_audio_head, 0);
|
||||
rb_define_method(cContext, "model_n_audio_layer", ruby_whisper_model_n_audio_layer, 0);
|
||||
rb_define_method(cContext, "model_n_text_ctx", ruby_whisper_model_n_text_ctx, 0);
|
||||
rb_define_method(cContext, "model_n_text_state", ruby_whisper_model_n_text_state, 0);
|
||||
rb_define_method(cContext, "model_n_text_head", ruby_whisper_model_n_text_head, 0);
|
||||
rb_define_method(cContext, "model_n_text_layer", ruby_whisper_model_n_text_layer, 0);
|
||||
rb_define_method(cContext, "model_n_mels", ruby_whisper_model_n_mels, 0);
|
||||
rb_define_method(cContext, "model_ftype", ruby_whisper_model_ftype, 0);
|
||||
rb_define_method(cContext, "model_type", ruby_whisper_model_type, 0);
|
||||
rb_define_method(cContext, "full_n_segments", ruby_whisper_full_n_segments, 0);
|
||||
rb_define_method(cContext, "full_lang_id", ruby_whisper_full_lang_id, 0);
|
||||
rb_define_method(cContext, "full_get_segment_t0", ruby_whisper_full_get_segment_t0, 1);
|
||||
rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1);
|
||||
rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1);
|
||||
rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1);
|
||||
rb_define_method(cContext, "full_get_segment_no_speech_prob", ruby_whisper_full_get_segment_no_speech_prob, 1);
|
||||
rb_define_method(cContext, "full", ruby_whisper_full, -1);
|
||||
rb_define_method(cContext, "full_parallel", ruby_whisper_full_parallel, -1);
|
||||
|
||||
// High leve
|
||||
rb_define_method(cContext, "full_get_segment", ruby_whisper_full_get_segment, 1);
|
||||
rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0);
|
||||
|
||||
rb_define_method(cContext, "model", ruby_whisper_get_model, 0);
|
||||
}
|
52
bindings/ruby/ext/ruby_whisper_error.c
Normal file
52
bindings/ruby/ext/ruby_whisper_error.c
Normal file
@ -0,0 +1,52 @@
|
||||
#include <ruby.h>
|
||||
|
||||
extern VALUE eError;
|
||||
|
||||
VALUE ruby_whisper_error_initialize(VALUE self, VALUE code)
|
||||
{
|
||||
const int c_code = NUM2INT(code);
|
||||
const char *raw_message;
|
||||
switch (c_code) {
|
||||
case -2:
|
||||
raw_message = "failed to compute log mel spectrogram";
|
||||
break;
|
||||
case -3:
|
||||
raw_message = "failed to auto-detect language";
|
||||
break;
|
||||
case -4:
|
||||
raw_message = "too many decoders requested";
|
||||
break;
|
||||
case -5:
|
||||
raw_message = "audio_ctx is larger than the maximum allowed";
|
||||
break;
|
||||
case -6:
|
||||
raw_message = "failed to encode";
|
||||
break;
|
||||
case -7:
|
||||
raw_message = "whisper_kv_cache_init() failed for self-attention cache";
|
||||
break;
|
||||
case -8:
|
||||
raw_message = "failed to decode";
|
||||
break;
|
||||
case -9:
|
||||
raw_message = "failed to decode";
|
||||
break;
|
||||
default:
|
||||
raw_message = "unknown error";
|
||||
break;
|
||||
}
|
||||
const VALUE message = rb_str_new2(raw_message);
|
||||
rb_call_super(1, &message);
|
||||
rb_iv_set(self, "@code", code);
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
void
|
||||
init_ruby_whisper_error(VALUE *mWhisper)
|
||||
{
|
||||
eError = rb_define_class_under(*mWhisper, "Error", rb_eStandardError);
|
||||
|
||||
rb_define_attr(eError, "code", true, false);
|
||||
rb_define_method(eError, "initialize", ruby_whisper_error_initialize, 1);
|
||||
}
|
210
bindings/ruby/ext/ruby_whisper_model.c
Normal file
210
bindings/ruby/ext/ruby_whisper_model.c
Normal file
@ -0,0 +1,210 @@
|
||||
#include <ruby.h>
|
||||
#include "ruby_whisper.h"
|
||||
|
||||
extern VALUE cModel;
|
||||
|
||||
static void rb_whisper_model_mark(ruby_whisper_model *rwm) {
|
||||
rb_gc_mark(rwm->context);
|
||||
}
|
||||
|
||||
static VALUE ruby_whisper_model_allocate(VALUE klass) {
|
||||
ruby_whisper_model *rwm;
|
||||
rwm = ALLOC(ruby_whisper_model);
|
||||
return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm);
|
||||
}
|
||||
|
||||
VALUE rb_whisper_model_initialize(VALUE context) {
|
||||
ruby_whisper_model *rwm;
|
||||
const VALUE model = ruby_whisper_model_allocate(cModel);
|
||||
Data_Get_Struct(model, ruby_whisper_model, rwm);
|
||||
rwm->context = context;
|
||||
return model;
|
||||
};
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* n_vocab -> Integer
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_model_n_vocab(VALUE self)
|
||||
{
|
||||
ruby_whisper_model *rwm;
|
||||
Data_Get_Struct(self, ruby_whisper_model, rwm);
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(rwm->context, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_vocab(rw->context));
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* n_audio_ctx -> Integer
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_model_n_audio_ctx(VALUE self)
|
||||
{
|
||||
ruby_whisper_model *rwm;
|
||||
Data_Get_Struct(self, ruby_whisper_model, rwm);
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(rwm->context, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_audio_ctx(rw->context));
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* n_audio_state -> Integer
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_model_n_audio_state(VALUE self)
|
||||
{
|
||||
ruby_whisper_model *rwm;
|
||||
Data_Get_Struct(self, ruby_whisper_model, rwm);
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(rwm->context, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_audio_state(rw->context));
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* n_audio_head -> Integer
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_model_n_audio_head(VALUE self)
|
||||
{
|
||||
ruby_whisper_model *rwm;
|
||||
Data_Get_Struct(self, ruby_whisper_model, rwm);
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(rwm->context, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_audio_head(rw->context));
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* n_audio_layer -> Integer
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_model_n_audio_layer(VALUE self)
|
||||
{
|
||||
ruby_whisper_model *rwm;
|
||||
Data_Get_Struct(self, ruby_whisper_model, rwm);
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(rwm->context, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_audio_layer(rw->context));
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* n_text_ctx -> Integer
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_model_n_text_ctx(VALUE self)
|
||||
{
|
||||
ruby_whisper_model *rwm;
|
||||
Data_Get_Struct(self, ruby_whisper_model, rwm);
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(rwm->context, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_text_ctx(rw->context));
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* n_text_state -> Integer
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_model_n_text_state(VALUE self)
|
||||
{
|
||||
ruby_whisper_model *rwm;
|
||||
Data_Get_Struct(self, ruby_whisper_model, rwm);
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(rwm->context, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_text_state(rw->context));
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* n_text_head -> Integer
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_model_n_text_head(VALUE self)
|
||||
{
|
||||
ruby_whisper_model *rwm;
|
||||
Data_Get_Struct(self, ruby_whisper_model, rwm);
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(rwm->context, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_text_head(rw->context));
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* n_text_layer -> Integer
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_model_n_text_layer(VALUE self)
|
||||
{
|
||||
ruby_whisper_model *rwm;
|
||||
Data_Get_Struct(self, ruby_whisper_model, rwm);
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(rwm->context, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_text_layer(rw->context));
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* n_mels -> Integer
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_model_n_mels(VALUE self)
|
||||
{
|
||||
ruby_whisper_model *rwm;
|
||||
Data_Get_Struct(self, ruby_whisper_model, rwm);
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(rwm->context, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_mels(rw->context));
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* ftype -> Integer
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_model_ftype(VALUE self)
|
||||
{
|
||||
ruby_whisper_model *rwm;
|
||||
Data_Get_Struct(self, ruby_whisper_model, rwm);
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(rwm->context, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_ftype(rw->context));
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* type -> String
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_model_type(VALUE self)
|
||||
{
|
||||
ruby_whisper_model *rwm;
|
||||
Data_Get_Struct(self, ruby_whisper_model, rwm);
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(rwm->context, ruby_whisper, rw);
|
||||
return rb_str_new2(whisper_model_type_readable(rw->context));
|
||||
}
|
||||
|
||||
void
|
||||
init_ruby_whisper_model(VALUE *mWhisper)
|
||||
{
|
||||
cModel = rb_define_class_under(*mWhisper, "Model", rb_cObject);
|
||||
|
||||
rb_define_alloc_func(cModel, ruby_whisper_model_allocate);
|
||||
rb_define_method(cModel, "n_vocab", ruby_whisper_model_n_vocab, 0);
|
||||
rb_define_method(cModel, "n_audio_ctx", ruby_whisper_model_n_audio_ctx, 0);
|
||||
rb_define_method(cModel, "n_audio_state", ruby_whisper_model_n_audio_state, 0);
|
||||
rb_define_method(cModel, "n_audio_head", ruby_whisper_model_n_audio_head, 0);
|
||||
rb_define_method(cModel, "n_audio_layer", ruby_whisper_model_n_audio_layer, 0);
|
||||
rb_define_method(cModel, "n_text_ctx", ruby_whisper_model_n_text_ctx, 0);
|
||||
rb_define_method(cModel, "n_text_state", ruby_whisper_model_n_text_state, 0);
|
||||
rb_define_method(cModel, "n_text_head", ruby_whisper_model_n_text_head, 0);
|
||||
rb_define_method(cModel, "n_text_layer", ruby_whisper_model_n_text_layer, 0);
|
||||
rb_define_method(cModel, "n_mels", ruby_whisper_model_n_mels, 0);
|
||||
rb_define_method(cModel, "ftype", ruby_whisper_model_ftype, 0);
|
||||
rb_define_method(cModel, "type", ruby_whisper_model_type, 0);
|
||||
}
|
1077
bindings/ruby/ext/ruby_whisper_params.c
Normal file
1077
bindings/ruby/ext/ruby_whisper_params.c
Normal file
File diff suppressed because it is too large
Load Diff
123
bindings/ruby/ext/ruby_whisper_segment.c
Normal file
123
bindings/ruby/ext/ruby_whisper_segment.c
Normal file
@ -0,0 +1,123 @@
|
||||
#include <ruby.h>
|
||||
#include "ruby_whisper.h"
|
||||
|
||||
extern VALUE cSegment;
|
||||
|
||||
static void
|
||||
rb_whisper_segment_mark(ruby_whisper_segment *rws)
|
||||
{
|
||||
rb_gc_mark(rws->context);
|
||||
}
|
||||
|
||||
VALUE
|
||||
ruby_whisper_segment_allocate(VALUE klass)
|
||||
{
|
||||
ruby_whisper_segment *rws;
|
||||
rws = ALLOC(ruby_whisper_segment);
|
||||
return Data_Wrap_Struct(klass, rb_whisper_segment_mark, RUBY_DEFAULT_FREE, rws);
|
||||
}
|
||||
|
||||
VALUE
|
||||
rb_whisper_segment_initialize(VALUE context, int index)
|
||||
{
|
||||
ruby_whisper_segment *rws;
|
||||
const VALUE segment = ruby_whisper_segment_allocate(cSegment);
|
||||
Data_Get_Struct(segment, ruby_whisper_segment, rws);
|
||||
rws->context = context;
|
||||
rws->index = index;
|
||||
return segment;
|
||||
};
|
||||
|
||||
/*
|
||||
* Start time in milliseconds.
|
||||
*
|
||||
* call-seq:
|
||||
* start_time -> Integer
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_segment_get_start_time(VALUE self)
|
||||
{
|
||||
ruby_whisper_segment *rws;
|
||||
Data_Get_Struct(self, ruby_whisper_segment, rws);
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(rws->context, ruby_whisper, rw);
|
||||
const int64_t t0 = whisper_full_get_segment_t0(rw->context, rws->index);
|
||||
// able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it
|
||||
return INT2NUM(t0 * 10);
|
||||
}
|
||||
|
||||
/*
|
||||
* End time in milliseconds.
|
||||
*
|
||||
* call-seq:
|
||||
* end_time -> Integer
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_segment_get_end_time(VALUE self)
|
||||
{
|
||||
ruby_whisper_segment *rws;
|
||||
Data_Get_Struct(self, ruby_whisper_segment, rws);
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(rws->context, ruby_whisper, rw);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(rw->context, rws->index);
|
||||
// able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it
|
||||
return INT2NUM(t1 * 10);
|
||||
}
|
||||
|
||||
/*
|
||||
* Whether the next segment is predicted as a speaker turn.
|
||||
*
|
||||
* call-seq:
|
||||
* speaker_turn_next? -> bool
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_segment_get_speaker_turn_next(VALUE self)
|
||||
{
|
||||
ruby_whisper_segment *rws;
|
||||
Data_Get_Struct(self, ruby_whisper_segment, rws);
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(rws->context, ruby_whisper, rw);
|
||||
return whisper_full_get_segment_speaker_turn_next(rw->context, rws->index) ? Qtrue : Qfalse;
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* text -> String
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_segment_get_text(VALUE self)
|
||||
{
|
||||
ruby_whisper_segment *rws;
|
||||
Data_Get_Struct(self, ruby_whisper_segment, rws);
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(rws->context, ruby_whisper, rw);
|
||||
const char * text = whisper_full_get_segment_text(rw->context, rws->index);
|
||||
return rb_str_new2(text);
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* no_speech_prob -> Float
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_segment_get_no_speech_prob(VALUE self)
|
||||
{
|
||||
ruby_whisper_segment *rws;
|
||||
Data_Get_Struct(self, ruby_whisper_segment, rws);
|
||||
ruby_whisper *rw;
|
||||
Data_Get_Struct(rws->context, ruby_whisper, rw);
|
||||
return DBL2NUM(whisper_full_get_segment_no_speech_prob(rw->context, rws->index));
|
||||
}
|
||||
|
||||
void
|
||||
init_ruby_whisper_segment(VALUE *mWhisper, VALUE *cContext)
|
||||
{
|
||||
cSegment = rb_define_class_under(*mWhisper, "Segment", rb_cObject);
|
||||
|
||||
rb_define_alloc_func(cSegment, ruby_whisper_segment_allocate);
|
||||
rb_define_method(cSegment, "start_time", ruby_whisper_segment_get_start_time, 0);
|
||||
rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0);
|
||||
rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0);
|
||||
rb_define_method(cSegment, "text", ruby_whisper_segment_get_text, 0);
|
||||
rb_define_method(cSegment, "no_speech_prob", ruby_whisper_segment_get_no_speech_prob, 0);
|
||||
}
|
159
bindings/ruby/ext/ruby_whisper_transcribe.cpp
Normal file
159
bindings/ruby/ext/ruby_whisper_transcribe.cpp
Normal file
@ -0,0 +1,159 @@
|
||||
#include <ruby.h>
|
||||
#include "ruby_whisper.h"
|
||||
#define DR_WAV_IMPLEMENTATION
|
||||
#include "dr_wav.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
extern ID id_to_s;
|
||||
extern ID id_call;
|
||||
|
||||
extern void
|
||||
register_callbacks(ruby_whisper_params * rwp, VALUE * self);
|
||||
|
||||
/*
|
||||
* transcribe a single file
|
||||
* can emit to a block results
|
||||
*
|
||||
* params = Whisper::Params.new
|
||||
* params.duration = 60_000
|
||||
* whisper.transcribe "path/to/audio.wav", params do |text|
|
||||
* puts text
|
||||
* end
|
||||
*
|
||||
* call-seq:
|
||||
* transcribe(path_to_audio, params) {|text| ...}
|
||||
**/
|
||||
VALUE
|
||||
ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
|
||||
ruby_whisper *rw;
|
||||
ruby_whisper_params *rwp;
|
||||
VALUE wave_file_path, blk, params;
|
||||
|
||||
rb_scan_args(argc, argv, "02&", &wave_file_path, ¶ms, &blk);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
Data_Get_Struct(params, ruby_whisper_params, rwp);
|
||||
|
||||
if (!rb_respond_to(wave_file_path, id_to_s)) {
|
||||
rb_raise(rb_eRuntimeError, "Expected file path to wave file");
|
||||
}
|
||||
|
||||
std::string fname_inp = StringValueCStr(wave_file_path);
|
||||
|
||||
std::vector<float> pcmf32; // mono-channel F32 PCM
|
||||
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
|
||||
|
||||
// WAV input - this is directly from main.cpp example
|
||||
{
|
||||
drwav wav;
|
||||
std::vector<uint8_t> wav_data; // used for pipe input from stdin
|
||||
|
||||
if (fname_inp == "-") {
|
||||
{
|
||||
uint8_t buf[1024];
|
||||
while (true) {
|
||||
const size_t n = fread(buf, 1, sizeof(buf), stdin);
|
||||
if (n == 0) {
|
||||
break;
|
||||
}
|
||||
wav_data.insert(wav_data.end(), buf, buf + n);
|
||||
}
|
||||
}
|
||||
|
||||
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
|
||||
fprintf(stderr, "error: failed to open WAV file from stdin\n");
|
||||
return self;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
|
||||
} else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
|
||||
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
|
||||
return self;
|
||||
}
|
||||
|
||||
if (wav.channels != 1 && wav.channels != 2) {
|
||||
fprintf(stderr, "WAV file '%s' must be mono or stereo\n", fname_inp.c_str());
|
||||
return self;
|
||||
}
|
||||
|
||||
if (rwp->diarize && wav.channels != 2 && rwp->params.print_timestamps == false) {
|
||||
fprintf(stderr, "WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", fname_inp.c_str());
|
||||
return self;
|
||||
}
|
||||
|
||||
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
|
||||
fprintf(stderr, "WAV file '%s' must be %i kHz\n", fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
|
||||
return self;
|
||||
}
|
||||
|
||||
if (wav.bitsPerSample != 16) {
|
||||
fprintf(stderr, "WAV file '%s' must be 16-bit\n", fname_inp.c_str());
|
||||
return self;
|
||||
}
|
||||
|
||||
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
|
||||
|
||||
std::vector<int16_t> pcm16;
|
||||
pcm16.resize(n*wav.channels);
|
||||
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
|
||||
drwav_uninit(&wav);
|
||||
|
||||
// convert to mono, float
|
||||
pcmf32.resize(n);
|
||||
if (wav.channels == 1) {
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
pcmf32[i] = float(pcm16[i])/32768.0f;
|
||||
}
|
||||
} else {
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
pcmf32[i] = float((int32_t)pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
|
||||
}
|
||||
}
|
||||
|
||||
if (rwp->diarize) {
|
||||
// convert to stereo, float
|
||||
pcmf32s.resize(2);
|
||||
|
||||
pcmf32s[0].resize(n);
|
||||
pcmf32s[1].resize(n);
|
||||
for (uint64_t i = 0; i < n; i++) {
|
||||
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
|
||||
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
||||
|
||||
rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
|
||||
bool is_aborted = *(bool*)user_data;
|
||||
return !is_aborted;
|
||||
};
|
||||
rwp->params.encoder_begin_callback_user_data = &is_aborted;
|
||||
}
|
||||
|
||||
register_callbacks(rwp, &self);
|
||||
|
||||
if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
|
||||
fprintf(stderr, "failed to process audio\n");
|
||||
return self;
|
||||
}
|
||||
const int n_segments = whisper_full_n_segments(rw->context);
|
||||
VALUE output = rb_str_new2("");
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(rw->context, i);
|
||||
output = rb_str_concat(output, rb_str_new2(text));
|
||||
}
|
||||
VALUE idCall = id_call;
|
||||
if (blk != Qnil) {
|
||||
rb_funcall(blk, idCall, 1, output);
|
||||
}
|
||||
return self;
|
||||
}
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
@ -1,2 +0,0 @@
|
||||
require "whisper.so"
|
||||
require "whisper/model/uri"
|
@ -1,157 +1,170 @@
|
||||
require "whisper.so"
|
||||
require "uri"
|
||||
require "net/http"
|
||||
require "time"
|
||||
require "pathname"
|
||||
require "io/console/size"
|
||||
|
||||
class Whisper::Model
|
||||
class URI
|
||||
def initialize(uri)
|
||||
@uri = URI(uri)
|
||||
end
|
||||
module Whisper
|
||||
class Model
|
||||
class URI
|
||||
def initialize(uri)
|
||||
@uri = URI(uri)
|
||||
end
|
||||
|
||||
def to_path
|
||||
cache
|
||||
cache_path.to_path
|
||||
end
|
||||
def to_path
|
||||
cache
|
||||
cache_path.to_path
|
||||
end
|
||||
|
||||
def clear_cache
|
||||
path = cache_path
|
||||
path.delete if path.exist?
|
||||
end
|
||||
def clear_cache
|
||||
path = cache_path
|
||||
path.delete if path.exist?
|
||||
end
|
||||
|
||||
private
|
||||
private
|
||||
|
||||
def cache_path
|
||||
base_cache_dir/@uri.host/@uri.path[1..]
|
||||
end
|
||||
def cache_path
|
||||
base_cache_dir/@uri.host/@uri.path[1..]
|
||||
end
|
||||
|
||||
def base_cache_dir
|
||||
base = case RUBY_PLATFORM
|
||||
when /mswin|mingw/
|
||||
ENV.key?("LOCALAPPDATA") ? Pathname(ENV["LOCALAPPDATA"]) : Pathname(Dir.home)/"AppData/Local"
|
||||
when /darwin/
|
||||
Pathname(Dir.home)/"Library/Caches"
|
||||
else
|
||||
ENV.key?("XDG_CACHE_HOME") ? ENV["XDG_CACHE_HOME"] : Pathname(Dir.home)/".cache"
|
||||
end
|
||||
base/"whisper.cpp"
|
||||
end
|
||||
def base_cache_dir
|
||||
base = case RUBY_PLATFORM
|
||||
when /mswin|mingw/
|
||||
ENV.key?("LOCALAPPDATA") ? Pathname(ENV["LOCALAPPDATA"]) : Pathname(Dir.home)/"AppData/Local"
|
||||
when /darwin/
|
||||
Pathname(Dir.home)/"Library/Caches"
|
||||
else
|
||||
ENV.key?("XDG_CACHE_HOME") ? ENV["XDG_CACHE_HOME"] : Pathname(Dir.home)/".cache"
|
||||
end
|
||||
base/"whisper.cpp"
|
||||
end
|
||||
|
||||
def cache
|
||||
path = cache_path
|
||||
headers = {}
|
||||
headers["if-modified-since"] = path.mtime.httpdate if path.exist?
|
||||
request @uri, headers
|
||||
path
|
||||
end
|
||||
def cache
|
||||
path = cache_path
|
||||
headers = {}
|
||||
headers["if-modified-since"] = path.mtime.httpdate if path.exist?
|
||||
request @uri, headers
|
||||
path
|
||||
end
|
||||
|
||||
def request(uri, headers)
|
||||
Net::HTTP.start uri.host, uri.port, use_ssl: uri.scheme == "https" do |http|
|
||||
request = Net::HTTP::Get.new(uri, headers)
|
||||
http.request request do |response|
|
||||
case response
|
||||
when Net::HTTPNotModified
|
||||
def request(uri, headers)
|
||||
Net::HTTP.start uri.host, uri.port, use_ssl: uri.scheme == "https" do |http|
|
||||
request = Net::HTTP::Get.new(uri, headers)
|
||||
http.request request do |response|
|
||||
case response
|
||||
when Net::HTTPNotModified
|
||||
# noop
|
||||
when Net::HTTPOK
|
||||
download response
|
||||
when Net::HTTPRedirection
|
||||
request URI(response["location"]), headers
|
||||
else
|
||||
return if headers.key?("if-modified-since") # Use cache file
|
||||
when Net::HTTPOK
|
||||
download response
|
||||
when Net::HTTPRedirection
|
||||
request URI(response["location"]), headers
|
||||
else
|
||||
return if headers.key?("if-modified-since") # Use cache file
|
||||
|
||||
raise "#{response.code} #{response.message}\n#{response.body}"
|
||||
raise "#{response.code} #{response.message}\n#{response.body}"
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
def download(response)
|
||||
path = cache_path
|
||||
path.dirname.mkpath unless path.dirname.exist?
|
||||
downloading_path = Pathname("#{path}.downloading")
|
||||
size = response.content_length
|
||||
downloading_path.open "wb" do |file|
|
||||
downloaded = 0
|
||||
response.read_body do |chunk|
|
||||
file << chunk
|
||||
downloaded += chunk.bytesize
|
||||
show_progress downloaded, size
|
||||
rescue => err
|
||||
if cache_path.exist?
|
||||
warn err
|
||||
# Use cache file
|
||||
else
|
||||
raise
|
||||
end
|
||||
end
|
||||
downloading_path.rename path
|
||||
end
|
||||
|
||||
def show_progress(current, size)
|
||||
return unless $stderr.tty?
|
||||
return unless size
|
||||
|
||||
unless @prev
|
||||
@prev = Time.now
|
||||
$stderr.puts "Downloading #{@uri}"
|
||||
def download(response)
|
||||
path = cache_path
|
||||
path.dirname.mkpath unless path.dirname.exist?
|
||||
downloading_path = Pathname("#{path}.downloading")
|
||||
size = response.content_length
|
||||
downloading_path.open "wb" do |file|
|
||||
downloaded = 0
|
||||
response.read_body do |chunk|
|
||||
file << chunk
|
||||
downloaded += chunk.bytesize
|
||||
show_progress downloaded, size
|
||||
end
|
||||
$stderr.puts
|
||||
end
|
||||
downloading_path.rename path
|
||||
end
|
||||
|
||||
now = Time.now
|
||||
return if now - @prev < 1 && current < size
|
||||
def show_progress(current, size)
|
||||
progress_rate_available = size && $stderr.tty?
|
||||
|
||||
progress_width = 20
|
||||
progress = current.to_f / size
|
||||
arrow_length = progress * progress_width
|
||||
arrow = "=" * (arrow_length - 1) + ">" + " " * (progress_width - arrow_length)
|
||||
line = "[#{arrow}] (#{format_bytesize(current)} / #{format_bytesize(size)})"
|
||||
padding = ' ' * ($stderr.winsize[1] - line.size)
|
||||
$stderr.print "\r#{line}#{padding}"
|
||||
$stderr.puts if current >= size
|
||||
@prev = now
|
||||
unless @prev
|
||||
@prev = Time.now
|
||||
$stderr.puts "Downloading #{@uri} to #{cache_path}"
|
||||
end
|
||||
|
||||
now = Time.now
|
||||
|
||||
if progress_rate_available
|
||||
return if now - @prev < 1 && current < size
|
||||
|
||||
progress_width = 20
|
||||
progress = current.to_f / size
|
||||
arrow_length = progress * progress_width
|
||||
arrow = "=" * (arrow_length - 1) + ">" + " " * (progress_width - arrow_length)
|
||||
line = "[#{arrow}] (#{format_bytesize(current)} / #{format_bytesize(size)})"
|
||||
padding = ' ' * ($stderr.winsize[1] - line.size)
|
||||
$stderr.print "\r#{line}#{padding}"
|
||||
else
|
||||
return if now - @prev < 1
|
||||
|
||||
$stderr.print "."
|
||||
end
|
||||
@prev = now
|
||||
end
|
||||
|
||||
def format_bytesize(bytesize)
|
||||
return "0.0 B" if bytesize.zero?
|
||||
|
||||
units = %w[B KiB MiB GiB TiB]
|
||||
exp = (Math.log(bytesize) / Math.log(1024)).to_i
|
||||
format("%.1f %s", bytesize.to_f / 1024 ** exp, units[exp])
|
||||
end
|
||||
end
|
||||
|
||||
def format_bytesize(bytesize)
|
||||
return "0.0 B" if bytesize.zero?
|
||||
@pre_converted_models = %w[
|
||||
tiny
|
||||
tiny.en
|
||||
tiny-q5_1
|
||||
tiny.en-q5_1
|
||||
tiny-q8_0
|
||||
base
|
||||
base.en
|
||||
base-q5_1
|
||||
base.en-q5_1
|
||||
base-q8_0
|
||||
small
|
||||
small.en
|
||||
small.en-tdrz
|
||||
small-q5_1
|
||||
small.en-q5_1
|
||||
small-q8_0
|
||||
medium
|
||||
medium.en
|
||||
medium-q5_0
|
||||
medium.en-q5_0
|
||||
medium-q8_0
|
||||
large-v1
|
||||
large-v2
|
||||
large-v2-q5_0
|
||||
large-v2-q8_0
|
||||
large-v3
|
||||
large-v3-q5_0
|
||||
large-v3-turbo
|
||||
large-v3-turbo-q5_0
|
||||
large-v3-turbo-q8_0
|
||||
].each_with_object({}) {|name, models|
|
||||
models[name] = URI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}.bin")
|
||||
}
|
||||
|
||||
units = %w[B KiB MiB GiB TiB]
|
||||
exp = (Math.log(bytesize) / Math.log(1024)).to_i
|
||||
format("%.1f %s", bytesize.to_f / 1024 ** exp, units[exp])
|
||||
class << self
|
||||
attr_reader :pre_converted_models
|
||||
end
|
||||
end
|
||||
|
||||
@pre_converted_models = {}
|
||||
%w[
|
||||
tiny
|
||||
tiny.en
|
||||
tiny-q5_1
|
||||
tiny.en-q5_1
|
||||
tiny-q8_0
|
||||
base
|
||||
base.en
|
||||
base-q5_1
|
||||
base.en-q5_1
|
||||
base-q8_0
|
||||
small
|
||||
small.en
|
||||
small.en-tdrz
|
||||
small-q5_1
|
||||
small.en-q5_1
|
||||
small-q8_0
|
||||
medium
|
||||
medium.en
|
||||
medium-q5_0
|
||||
medium.en-q5_0
|
||||
medium-q8_0
|
||||
large-v1
|
||||
large-v2
|
||||
large-v2-q5_0
|
||||
large-v2-q8_0
|
||||
large-v3
|
||||
large-v3-q5_0
|
||||
large-v3-turbo
|
||||
large-v3-turbo-q5_0
|
||||
large-v3-turbo-q8_0
|
||||
].each do |name|
|
||||
@pre_converted_models[name] = URI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}.bin")
|
||||
end
|
||||
|
||||
class << self
|
||||
attr_reader :pre_converted_models
|
||||
end
|
||||
end
|
||||
|
189
bindings/ruby/sig/whisper.rbs
Normal file
189
bindings/ruby/sig/whisper.rbs
Normal file
@ -0,0 +1,189 @@
|
||||
module Whisper
|
||||
interface _Samples
|
||||
def length: () -> Integer
|
||||
def each: { (Float) -> void } -> void
|
||||
end
|
||||
|
||||
type log_callback = ^(Integer level, String message, Object user_data) -> void
|
||||
type new_segment_callback = ^(Whisper::Context, void, Integer n_new, Object user_data) -> void
|
||||
type progress_callback = ^(Whisper::Context, void, Integer progress, Object user_data) -> void
|
||||
type abort_callback = ^(Whisper::Context, void, Object user_data) -> boolish
|
||||
|
||||
LOG_LEVEL_NONE: Integer
|
||||
LOG_LEVEL_INFO: Integer
|
||||
LOG_LEVEL_WARN: Integer
|
||||
LOG_LEVEL_ERROR: Integer
|
||||
LOG_LEVEL_DEBUG: Integer
|
||||
LOG_LEVEL_CONT: Integer
|
||||
|
||||
def self.lang_max_id: () -> Integer
|
||||
def self.lang_id: (string name) -> Integer
|
||||
def self.lang_str: (Integer id) -> String
|
||||
def self.lang_str_full: (Integer id) -> String
|
||||
def self.log_set: (log_callback, Object? user_data) -> log_callback
|
||||
|
||||
class Context
|
||||
def self.new: (string | _ToPath | ::URI::HTTP) -> instance
|
||||
def transcribe: (string, Params) -> self
|
||||
| (string, Params) { (String) -> void } -> self
|
||||
def model_n_vocab: () -> Integer
|
||||
def model_n_audio_ctx: () -> Integer
|
||||
def model_n_audio_state: () -> Integer
|
||||
def model_n_text_head: () -> Integer
|
||||
def model_n_text_layer: () -> Integer
|
||||
def model_n_mels: () -> Integer
|
||||
def model_ftype: () -> Integer
|
||||
def model_type: () -> String
|
||||
def each_segment: { (Segment) -> void } -> void
|
||||
| () -> Enumerator[Segment]
|
||||
def model: () -> Model
|
||||
def full_get_segment: (Integer nth) -> Segment
|
||||
def full_n_segments: () -> Integer
|
||||
def full_lang_id: () -> Integer
|
||||
def full_get_segment_t0: (Integer) -> Integer
|
||||
def full_get_segment_t1: (Integer) -> Integer
|
||||
def full_get_segment_speaker_turn_next: (Integer) -> (true | false)
|
||||
def full_get_segment_text: (Integer) -> String
|
||||
def full_get_segment_no_speech_prob: (Integer) -> Float
|
||||
def full: (Params, Array[Float] samples, ?Integer n_samples) -> self
|
||||
| (Params, _Samples, ?Integer n_samples) -> self
|
||||
def full_parallel: (Params, Array[Float], ?Integer n_samples) -> self
|
||||
| (Params, _Samples, ?Integer n_samples) -> self
|
||||
| (Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self
|
||||
end
|
||||
|
||||
class Params
|
||||
def self.new: (
|
||||
?language: string,
|
||||
?translate: boolish,
|
||||
?no_context: boolish,
|
||||
?single_segment: boolish,
|
||||
?print_special: boolish,
|
||||
?print_progress: boolish,
|
||||
?print_realtime: boolish,
|
||||
?print_timestamps: boolish,
|
||||
?suppress_blank: boolish,
|
||||
?suppress_nst: boolish,
|
||||
?token_timestamps: boolish,
|
||||
?split_on_word: boolish,
|
||||
?initial_prompt: string | nil,
|
||||
?diarize: boolish,
|
||||
?offset: Integer,
|
||||
?duration: Integer,
|
||||
?max_text_tokens: Integer,
|
||||
?temperature: Float,
|
||||
?max_initial_ts: Float,
|
||||
?length_penalty: Float,
|
||||
?temperature_inc: Float,
|
||||
?entropy_thold: Float,
|
||||
?logprob_thold: Float,
|
||||
?no_speech_thold: Float,
|
||||
?new_segment_callback: new_segment_callback,
|
||||
?new_segment_callback_user_data: Object,
|
||||
?progress_callback: progress_callback,
|
||||
?progress_callback_user_data: Object,
|
||||
?abort_callback: abort_callback,
|
||||
?abort_callback_user_data: Object
|
||||
) -> instance
|
||||
def language=: (String) -> String # TODO: Enumerate lang names
|
||||
def language: () -> String
|
||||
def translate=: (boolish) -> boolish
|
||||
def translate: () -> (true | false)
|
||||
def no_context=: (boolish) -> boolish
|
||||
def no_context: () -> (true | false)
|
||||
def single_segment=: (boolish) -> boolish
|
||||
def single_segment: () -> (true | false)
|
||||
def print_special=: (boolish) -> boolish
|
||||
def print_special: () -> (true | false)
|
||||
def print_progress=: (boolish) -> boolish
|
||||
def print_progress: () -> (true | false)
|
||||
def print_realtime=: (boolish) -> boolish
|
||||
def print_realtime: () -> (true | false)
|
||||
def print_timestamps=: (boolish) -> boolish
|
||||
def print_timestamps: () -> (true | false)
|
||||
def suppress_blank=: (boolish) -> boolish
|
||||
def suppress_blank: () -> (true | false)
|
||||
def suppress_nst=: (boolish) -> boolish
|
||||
def suppress_nst: () -> (true | false)
|
||||
def token_timestamps=: (boolish) -> boolish
|
||||
def token_timestamps: () -> (true | false)
|
||||
def split_on_word=: (boolish) -> boolish
|
||||
def split_on_word: () -> (true | false)
|
||||
def initial_prompt=: (_ToS) -> _ToS
|
||||
def initial_prompt: () -> (String | nil)
|
||||
def diarize=: (boolish) -> boolish
|
||||
def diarize: () -> (true | false)
|
||||
def offset=: (Integer) -> Integer
|
||||
def offset: () -> Integer
|
||||
def duration=: (Integer) -> Integer
|
||||
def duration: () -> Integer
|
||||
def max_text_tokens=: (Integer) -> Integer
|
||||
def max_text_tokens: () -> Integer
|
||||
def temperature=: (Float) -> Float
|
||||
def temperature: () -> Float
|
||||
def max_initial_ts=: (Float) -> Float
|
||||
def max_initial_ts: () -> Float
|
||||
def length_penalty=: (Float) -> Float
|
||||
def length_penalty: () -> Float
|
||||
def temperature_inc=: (Float) -> Float
|
||||
def temperature_inc: () -> Float
|
||||
def entropy_thold=: (Float) -> Float
|
||||
def entropy_thold: () -> Float
|
||||
def logprob_thold=: (Float) -> Float
|
||||
def logprob_thold: () -> Float
|
||||
def no_speech_thold=: (Float) -> Float
|
||||
def no_speech_thold: () -> Float
|
||||
def new_segment_callback=: (new_segment_callback) -> new_segment_callback
|
||||
def new_segment_callback: () -> (new_segment_callback | nil)
|
||||
def new_segment_callback_user_data=: (Object) -> Object
|
||||
def new_segment_callback_user_data: () -> Object
|
||||
def progress_callback=: (progress_callback) -> progress_callback
|
||||
def progress_callback: () -> (progress_callback | nil)
|
||||
def progress_callback_user_data=: (Object) -> Object
|
||||
def progress_callback_user_data: () -> Object
|
||||
def abort_callback=: (abort_callback) -> abort_callback
|
||||
def abort_callback: () -> (abort_callback | nil)
|
||||
def abort_callback_user_data=: (Object) -> Object
|
||||
def abort_callback_user_data: () -> Object
|
||||
def on_new_segment: { (Segment) -> void } -> void
|
||||
def on_progress: { (Integer progress) -> void } -> void
|
||||
def abort_on: { (Object user_data) -> boolish } -> void
|
||||
end
|
||||
|
||||
class Model
|
||||
def self.pre_converted_models: () -> Hash[String, Model::URI]
|
||||
def self.new: () -> instance
|
||||
def n_vocab: () -> Integer
|
||||
def n_audio_ctx: () -> Integer
|
||||
def n_audio_state: () -> Integer
|
||||
def n_audio_head: () -> Integer
|
||||
def n_audio_layer: () -> Integer
|
||||
def n_text_ctx: () -> Integer
|
||||
def n_text_state: () -> Integer
|
||||
def n_text_head: () -> Integer
|
||||
def n_text_layer: () -> Integer
|
||||
def n_mels: () -> Integer
|
||||
def ftype: () -> Integer
|
||||
def type: () -> String
|
||||
|
||||
class URI
|
||||
def self.new: (string | ::URI::HTTP) -> self
|
||||
def to_path: -> String
|
||||
def clear_cache: -> void
|
||||
end
|
||||
end
|
||||
|
||||
class Segment
|
||||
def start_time: () -> Integer
|
||||
def end_time: () -> Integer
|
||||
def speaker_next_turn?: () -> (true | false)
|
||||
def text: () -> String
|
||||
def no_speech_prob: () -> Float
|
||||
end
|
||||
|
||||
class Error < StandardError
|
||||
attr_reader code: Integer
|
||||
|
||||
def self.new: (Integer code) -> instance
|
||||
end
|
||||
end
|
@ -4,4 +4,21 @@ require_relative "jfk_reader/jfk_reader"
|
||||
|
||||
class TestBase < Test::Unit::TestCase
|
||||
AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")
|
||||
|
||||
class << self
|
||||
attr_reader :whisper
|
||||
|
||||
def startup
|
||||
@whisper = Whisper::Context.new("base.en")
|
||||
params = Whisper::Params.new
|
||||
params.print_timestamps = false
|
||||
@whisper.transcribe(TestBase::AUDIO, params)
|
||||
end
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def whisper
|
||||
self.class.whisper
|
||||
end
|
||||
end
|
||||
|
@ -68,4 +68,42 @@ class TestModel < TestBase
|
||||
assert_path_exist path
|
||||
assert_equal 147964211, File.size(path)
|
||||
end
|
||||
|
||||
def test_uri_string
|
||||
path = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin"
|
||||
whisper = Whisper::Context.new(path)
|
||||
model = whisper.model
|
||||
|
||||
assert_equal 51864, model.n_vocab
|
||||
assert_equal 1500, model.n_audio_ctx
|
||||
assert_equal 512, model.n_audio_state
|
||||
assert_equal 8, model.n_audio_head
|
||||
assert_equal 6, model.n_audio_layer
|
||||
assert_equal 448, model.n_text_ctx
|
||||
assert_equal 512, model.n_text_state
|
||||
assert_equal 8, model.n_text_head
|
||||
assert_equal 6, model.n_text_layer
|
||||
assert_equal 80, model.n_mels
|
||||
assert_equal 1, model.ftype
|
||||
assert_equal "base", model.type
|
||||
end
|
||||
|
||||
def test_uri
|
||||
path = URI("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin")
|
||||
whisper = Whisper::Context.new(path)
|
||||
model = whisper.model
|
||||
|
||||
assert_equal 51864, model.n_vocab
|
||||
assert_equal 1500, model.n_audio_ctx
|
||||
assert_equal 512, model.n_audio_state
|
||||
assert_equal 8, model.n_audio_head
|
||||
assert_equal 6, model.n_audio_layer
|
||||
assert_equal 448, model.n_text_ctx
|
||||
assert_equal 512, model.n_text_state
|
||||
assert_equal 8, model.n_text_head
|
||||
assert_equal 6, model.n_text_layer
|
||||
assert_equal 80, model.n_mels
|
||||
assert_equal 1, model.ftype
|
||||
assert_equal "base", model.type
|
||||
end
|
||||
end
|
||||
|
@ -23,7 +23,7 @@ class TestPackage < TestBase
|
||||
version = match_data[2]
|
||||
basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}"
|
||||
Dir.mktmpdir do |dir|
|
||||
system "gem", "install", "--install-dir", dir.shellescape, "pkg/#{filename.shellescape}", exception: true
|
||||
system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{filename.shellescape}", exception: true
|
||||
assert_path_exist File.join(dir, "gems/whispercpp-#{version}/lib", basename)
|
||||
end
|
||||
end
|
||||
|
@ -1,6 +1,39 @@
|
||||
require_relative "helper"
|
||||
|
||||
class TestParams < TestBase
|
||||
PARAM_NAMES = [
|
||||
:language,
|
||||
:translate,
|
||||
:no_context,
|
||||
:single_segment,
|
||||
:print_special,
|
||||
:print_progress,
|
||||
:print_realtime,
|
||||
:print_timestamps,
|
||||
:suppress_blank,
|
||||
:suppress_nst,
|
||||
:token_timestamps,
|
||||
:split_on_word,
|
||||
:initial_prompt,
|
||||
:diarize,
|
||||
:offset,
|
||||
:duration,
|
||||
:max_text_tokens,
|
||||
:temperature,
|
||||
:max_initial_ts,
|
||||
:length_penalty,
|
||||
:temperature_inc,
|
||||
:entropy_thold,
|
||||
:logprob_thold,
|
||||
:no_speech_thold,
|
||||
:new_segment_callback,
|
||||
:new_segment_callback_user_data,
|
||||
:progress_callback,
|
||||
:progress_callback_user_data,
|
||||
:abort_callback,
|
||||
:abort_callback_user_data,
|
||||
]
|
||||
|
||||
def setup
|
||||
@params = Whisper::Params.new
|
||||
end
|
||||
@ -89,11 +122,11 @@ class TestParams < TestBase
|
||||
assert !@params.suppress_blank
|
||||
end
|
||||
|
||||
def test_suppress_non_speech_tokens
|
||||
@params.suppress_non_speech_tokens = true
|
||||
assert @params.suppress_non_speech_tokens
|
||||
@params.suppress_non_speech_tokens = false
|
||||
assert !@params.suppress_non_speech_tokens
|
||||
def test_suppress_nst
|
||||
@params.suppress_nst = true
|
||||
assert @params.suppress_nst
|
||||
@params.suppress_nst = false
|
||||
assert !@params.suppress_nst
|
||||
end
|
||||
|
||||
def test_token_timestamps
|
||||
@ -157,4 +190,57 @@ class TestParams < TestBase
|
||||
@params.no_speech_thold = 0.2
|
||||
assert_in_delta 0.2, @params.no_speech_thold
|
||||
end
|
||||
|
||||
def test_new_with_kw_args
|
||||
params = Whisper::Params.new(language: "es")
|
||||
assert_equal "es", params.language
|
||||
assert_equal 1.0, params.max_initial_ts
|
||||
end
|
||||
|
||||
def test_new_with_kw_args_non_existent
|
||||
assert_raise ArgumentError do
|
||||
Whisper::Params.new(non_existent: "value")
|
||||
end
|
||||
end
|
||||
|
||||
def test_new_with_kw_args_wrong_type
|
||||
assert_raise TypeError do
|
||||
Whisper::Params.new(language: 3)
|
||||
end
|
||||
end
|
||||
|
||||
data(PARAM_NAMES.collect {|param| [param, param]}.to_h)
|
||||
def test_new_with_kw_args_default_values(param)
|
||||
default_value = @params.send(param)
|
||||
value = case [param, default_value]
|
||||
in [*, true | false]
|
||||
!default_value
|
||||
in [*, Integer | Float]
|
||||
default_value + 1
|
||||
in [:language, *]
|
||||
"es"
|
||||
in [:initial_prompt, *]
|
||||
"Initial prompt"
|
||||
in [/_callback\Z/, *]
|
||||
proc {}
|
||||
in [/_user_data\Z/, *]
|
||||
Object.new
|
||||
end
|
||||
params = Whisper::Params.new(param => value)
|
||||
if Float === value
|
||||
assert_in_delta value, params.send(param)
|
||||
else
|
||||
assert_equal value, params.send(param)
|
||||
end
|
||||
|
||||
PARAM_NAMES.reject {|name| name == param}.each do |name|
|
||||
expected = @params.send(name)
|
||||
actual = params.send(name)
|
||||
if Float === expected
|
||||
assert_in_delta expected, actual
|
||||
else
|
||||
assert_equal expected, actual
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -1,17 +1,6 @@
|
||||
require_relative "helper"
|
||||
|
||||
class TestSegment < TestBase
|
||||
class << self
|
||||
attr_reader :whisper
|
||||
|
||||
def startup
|
||||
@whisper = Whisper::Context.new("base.en")
|
||||
params = Whisper::Params.new
|
||||
params.print_timestamps = false
|
||||
@whisper.transcribe(TestBase::AUDIO, params)
|
||||
end
|
||||
end
|
||||
|
||||
def test_iteration
|
||||
whisper.each_segment do |segment|
|
||||
assert_instance_of Whisper::Segment, segment
|
||||
@ -43,6 +32,14 @@ class TestSegment < TestBase
|
||||
end
|
||||
end
|
||||
|
||||
def test_no_speech_prob
|
||||
no_speech_prob = nil
|
||||
whisper.each_segment do |segment|
|
||||
no_speech_prob = segment.no_speech_prob
|
||||
end
|
||||
assert no_speech_prob > 0.0
|
||||
end
|
||||
|
||||
def test_on_new_segment
|
||||
params = Whisper::Params.new
|
||||
seg = nil
|
||||
@ -74,10 +71,4 @@ class TestSegment < TestBase
|
||||
end
|
||||
whisper.transcribe(AUDIO, params)
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def whisper
|
||||
self.class.whisper
|
||||
end
|
||||
end
|
||||
|
@ -21,21 +21,6 @@ class TestWhisper < TestBase
|
||||
end
|
||||
|
||||
sub_test_case "After transcription" do
|
||||
class << self
|
||||
attr_reader :whisper
|
||||
|
||||
def startup
|
||||
@whisper = Whisper::Context.new("base.en")
|
||||
params = Whisper::Params.new
|
||||
params.print_timestamps = false
|
||||
@whisper.transcribe(TestBase::AUDIO, params)
|
||||
end
|
||||
end
|
||||
|
||||
def whisper
|
||||
self.class.whisper
|
||||
end
|
||||
|
||||
def test_full_n_segments
|
||||
assert_equal 1, whisper.full_n_segments
|
||||
end
|
||||
@ -44,6 +29,12 @@ class TestWhisper < TestBase
|
||||
assert_equal 0, whisper.full_lang_id
|
||||
end
|
||||
|
||||
def test_full_get_segment
|
||||
segment = whisper.full_get_segment(0)
|
||||
assert_equal 0, segment.start_time
|
||||
assert_match /ask not what your country can do for you, ask what you can do for your country/, segment.text
|
||||
end
|
||||
|
||||
def test_full_get_segment_t0
|
||||
assert_equal 0, whisper.full_get_segment_t0(0)
|
||||
assert_raise IndexError do
|
||||
@ -70,6 +61,12 @@ class TestWhisper < TestBase
|
||||
def test_full_get_segment_text
|
||||
assert_match /ask not what your country can do for you, ask what you can do for your country/, whisper.full_get_segment_text(0)
|
||||
end
|
||||
|
||||
def test_full_get_segment_no_speech_prob
|
||||
prob = whisper.full_get_segment_no_speech_prob(0)
|
||||
assert prob > 0.0
|
||||
assert prob < 1.0
|
||||
end
|
||||
end
|
||||
|
||||
def test_lang_max_id
|
||||
|
28
close-issue.yml
Normal file
28
close-issue.yml
Normal file
@ -0,0 +1,28 @@
|
||||
name: Close inactive issues
|
||||
on:
|
||||
schedule:
|
||||
- cron: "42 0 * * *"
|
||||
|
||||
# Fine-grant permission
|
||||
# https://docs.github.com/en/actions/security-for-github-actions/security-guides/automatic-token-authentication#modifying-the-permissions-for-the-github_token
|
||||
permissions:
|
||||
issues: write
|
||||
|
||||
jobs:
|
||||
close-issues:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: actions/stale@v5
|
||||
with:
|
||||
exempt-issue-labels: "refactor,help wanted,good first issue,research,bug,roadmap"
|
||||
days-before-issue-stale: 30
|
||||
days-before-issue-close: 14
|
||||
stale-issue-label: "stale"
|
||||
close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
|
||||
days-before-pr-stale: -1
|
||||
days-before-pr-close: -1
|
||||
operations-per-run: 10000
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
@ -13,5 +13,4 @@ set_target_properties(${TARGET}
|
||||
PROPERTIES
|
||||
EXPORT_COMPILE_COMMANDS ON
|
||||
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin"
|
||||
INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib"
|
||||
)
|
||||
|
@ -42,6 +42,8 @@ endif()
|
||||
if(MSVC)
|
||||
set(BUILD_COMPILER "${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}")
|
||||
set(BUILD_TARGET ${CMAKE_VS_PLATFORM_NAME})
|
||||
add_compile_options("$<$<COMPILE_LANGUAGE:C>:/utf-8>")
|
||||
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/utf-8>")
|
||||
else()
|
||||
execute_process(
|
||||
COMMAND sh -c "$@ --version | head -1" _ ${CMAKE_C_COMPILER}
|
||||
|
@ -330,6 +330,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
|
||||
bool no_timestamps = whisper_params.Get("no_timestamps").As<Napi::Boolean>();
|
||||
int32_t audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>();
|
||||
bool comma_in_time = whisper_params.Get("comma_in_time").As<Napi::Boolean>();
|
||||
int32_t max_len = whisper_params.Get("max_len").As<Napi::Number>();
|
||||
|
||||
Napi::Value pcmf32Value = whisper_params.Get("pcmf32");
|
||||
std::vector<float> pcmf32_vec;
|
||||
@ -352,6 +353,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
|
||||
params.audio_ctx = audio_ctx;
|
||||
params.pcmf32 = pcmf32_vec;
|
||||
params.comma_in_time = comma_in_time;
|
||||
params.max_len = max_len;
|
||||
|
||||
Napi::Function callback = info[1].As<Napi::Function>();
|
||||
Worker* worker = new Worker(callback, params);
|
||||
|
@ -18,6 +18,7 @@ const whisperParams = {
|
||||
translate: true,
|
||||
no_timestamps: false,
|
||||
audio_ctx: 0,
|
||||
max_len: 0,
|
||||
};
|
||||
|
||||
const arguments = process.argv.slice(2);
|
||||
|
@ -12,6 +12,11 @@
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
|
||||
#if defined(_WIN32)
|
||||
#define NOMINMAX
|
||||
#include <windows.h>
|
||||
#endif
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
@ -43,6 +48,7 @@ struct whisper_params {
|
||||
float word_thold = 0.01f;
|
||||
float entropy_thold = 2.40f;
|
||||
float logprob_thold = -1.00f;
|
||||
float no_speech_thold = 0.6f;
|
||||
float grammar_penalty = 100.0f;
|
||||
float temperature = 0.0f;
|
||||
float temperature_inc = 0.2f;
|
||||
@ -70,6 +76,7 @@ struct whisper_params {
|
||||
bool log_score = false;
|
||||
bool use_gpu = true;
|
||||
bool flash_attn = false;
|
||||
bool suppress_nst = false;
|
||||
|
||||
std::string language = "en";
|
||||
std::string prompt;
|
||||
@ -104,6 +111,11 @@ static char * whisper_param_turn_lowercase(char * in){
|
||||
return in;
|
||||
}
|
||||
|
||||
static char * requires_value_error(const std::string & arg) {
|
||||
fprintf(stderr, "error: argument %s requires value\n", arg.c_str());
|
||||
exit(0);
|
||||
}
|
||||
|
||||
static bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
for (int i = 1; i < argc; i++) {
|
||||
std::string arg = argv[i];
|
||||
@ -122,21 +134,23 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
|
||||
else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
|
||||
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
|
||||
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
|
||||
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
|
||||
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-tp" || arg == "--temperature") { params.temperature = std::stof(argv[++i]); }
|
||||
else if (arg == "-tpi" || arg == "--temperature-inc") { params.temperature_inc = std::stof(argv[++i]); }
|
||||
#define ARGV_NEXT (((i + 1) < argc) ? argv[++i] : requires_value_error(arg))
|
||||
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(ARGV_NEXT); }
|
||||
else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(ARGV_NEXT); }
|
||||
else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(ARGV_NEXT); }
|
||||
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(ARGV_NEXT); }
|
||||
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(ARGV_NEXT); }
|
||||
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(ARGV_NEXT); }
|
||||
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(ARGV_NEXT); }
|
||||
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(ARGV_NEXT); }
|
||||
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(ARGV_NEXT); }
|
||||
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(ARGV_NEXT); }
|
||||
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(ARGV_NEXT); }
|
||||
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(ARGV_NEXT); }
|
||||
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(ARGV_NEXT); }
|
||||
else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(ARGV_NEXT); }
|
||||
else if (arg == "-tp" || arg == "--temperature") { params.temperature = std::stof(ARGV_NEXT); }
|
||||
else if (arg == "-tpi" || arg == "--temperature-inc") { params.temperature_inc = std::stof(ARGV_NEXT); }
|
||||
else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
|
||||
@ -148,30 +162,31 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
|
||||
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
|
||||
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
|
||||
else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; }
|
||||
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
|
||||
else if (arg == "-fp" || arg == "--font-path") { params.font_path = ARGV_NEXT; }
|
||||
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
|
||||
else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
|
||||
else if (arg == "-ojf" || arg == "--output-json-full"){ params.output_jsn_full = params.output_jsn = true; }
|
||||
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
|
||||
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(ARGV_NEXT); }
|
||||
else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
||||
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
|
||||
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
|
||||
else if (arg == "-l" || arg == "--language") { params.language = whisper_param_turn_lowercase(argv[++i]); }
|
||||
else if (arg == "-l" || arg == "--language") { params.language = whisper_param_turn_lowercase(ARGV_NEXT); }
|
||||
else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; }
|
||||
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
|
||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
|
||||
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
|
||||
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
|
||||
else if ( arg == "--prompt") { params.prompt = ARGV_NEXT; }
|
||||
else if (arg == "-m" || arg == "--model") { params.model = ARGV_NEXT; }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(ARGV_NEXT); }
|
||||
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = ARGV_NEXT; }
|
||||
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = ARGV_NEXT; }
|
||||
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
|
||||
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
||||
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
|
||||
else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
|
||||
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
|
||||
else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; }
|
||||
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
|
||||
else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; }
|
||||
else if ( arg == "--suppress-regex") { params.suppress_regex = ARGV_NEXT; }
|
||||
else if ( arg == "--grammar") { params.grammar = ARGV_NEXT; }
|
||||
else if ( arg == "--grammar-rule") { params.grammar_rule = ARGV_NEXT; }
|
||||
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(ARGV_NEXT); }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
@ -202,6 +217,7 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params
|
||||
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
|
||||
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
|
||||
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
|
||||
fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold);
|
||||
fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature);
|
||||
fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n",params.temperature_inc);
|
||||
fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
|
||||
@ -234,6 +250,7 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params
|
||||
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
|
||||
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
||||
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
|
||||
fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false");
|
||||
fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
|
||||
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
|
||||
fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
|
||||
@ -904,6 +921,13 @@ static bool output_lrc(struct whisper_context * ctx, const char * fname, const w
|
||||
static void cb_log_disable(enum ggml_log_level , const char * , void * ) { }
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
#if defined(_WIN32)
|
||||
// Set the console output code page to UTF-8, while command line arguments
|
||||
// are still encoded in the system's code page. In this way, we can print
|
||||
// non-ASCII characters to the console, and access files with non-ASCII paths.
|
||||
SetConsoleOutputCP(CP_UTF8);
|
||||
#endif
|
||||
|
||||
whisper_params params;
|
||||
|
||||
// If the only argument starts with "@", read arguments line-by-line
|
||||
@ -1121,9 +1145,12 @@ int main(int argc, char ** argv) {
|
||||
|
||||
wparams.entropy_thold = params.entropy_thold;
|
||||
wparams.logprob_thold = params.logprob_thold;
|
||||
wparams.no_speech_thold = params.no_speech_thold;
|
||||
|
||||
wparams.no_timestamps = params.no_timestamps;
|
||||
|
||||
wparams.suppress_nst = params.suppress_nst;
|
||||
|
||||
whisper_print_user_data user_data = { ¶ms, &pcmf32s, 0 };
|
||||
|
||||
const auto & grammar_parsed = params.grammar_parsed;
|
||||
|
@ -21,6 +21,12 @@
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <chrono>
|
||||
|
||||
#if defined(_WIN32)
|
||||
#define NOMINMAX
|
||||
#include <windows.h>
|
||||
#endif
|
||||
|
||||
// command-line parameters
|
||||
struct whisper_params {
|
||||
@ -679,6 +685,10 @@ static int process_general_transcription(struct whisper_context * ctx, audio_asy
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
#if defined(_WIN32)
|
||||
SetConsoleOutputCP(CP_UTF8);
|
||||
#endif
|
||||
|
||||
whisper_params params;
|
||||
|
||||
if (whisper_params_parse(argc, argv, params) == false) {
|
||||
|
@ -24,6 +24,10 @@ int main(int argc, char** argv) {
|
||||
replacement_filename = "whisper-cli";
|
||||
}
|
||||
|
||||
if (filename == "main.exe") {
|
||||
replacement_filename = "whisper-cli.exe";
|
||||
}
|
||||
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "WARNING: The binary '%s' is deprecated.\n", filename.c_str());
|
||||
fprintf(stdout, " Please use '%s' instead.\n", replacement_filename.c_str());
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include <vector>
|
||||
#include <deque>
|
||||
#include <set>
|
||||
#include <chrono>
|
||||
|
||||
using json = nlohmann::json;
|
||||
|
||||
@ -181,7 +182,7 @@ static json unguided_transcription(struct whisper_context * ctx, audio_async &au
|
||||
wparams.n_threads = params.n_threads;
|
||||
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.suppress_non_speech_tokens = true;
|
||||
wparams.suppress_nst = true;
|
||||
// run the transformer and a single decoding pass
|
||||
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
||||
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
|
||||
@ -225,7 +226,7 @@ static json guided_transcription(struct whisper_context * ctx, audio_async &audi
|
||||
wparams.prompt_tokens = cs.prompt_tokens.data();
|
||||
wparams.prompt_n_tokens = cs.prompt_tokens.size();
|
||||
// TODO: properly expose as option
|
||||
wparams.suppress_non_speech_tokens = true;
|
||||
wparams.suppress_nst = true;
|
||||
|
||||
// run the transformer and a single decoding pass
|
||||
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
||||
|
@ -12,6 +12,7 @@
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
#include <sstream>
|
||||
#include <chrono>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
@ -61,6 +62,7 @@ struct whisper_params {
|
||||
float logprob_thold = -1.00f;
|
||||
float temperature = 0.00f;
|
||||
float temperature_inc = 0.20f;
|
||||
float no_speech_thold = 0.6f;
|
||||
|
||||
bool debug_mode = false;
|
||||
bool translate = false;
|
||||
@ -76,6 +78,7 @@ struct whisper_params {
|
||||
bool no_timestamps = false;
|
||||
bool use_gpu = true;
|
||||
bool flash_attn = false;
|
||||
bool suppress_nst = false;
|
||||
|
||||
std::string language = "en";
|
||||
std::string prompt = "";
|
||||
@ -134,7 +137,9 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " --public PATH, [%-7s] Path to the public folder\n", sparams.public_path.c_str());
|
||||
fprintf(stderr, " --request-path PATH, [%-7s] Request path for all requests\n", sparams.request_path.c_str());
|
||||
fprintf(stderr, " --inference-path PATH, [%-7s] Inference path for all requests\n", sparams.inference_path.c_str());
|
||||
fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false");
|
||||
fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server\n", sparams.ffmpeg_converter ? "true" : "false");
|
||||
fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false");
|
||||
fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
@ -179,6 +184,9 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
|
||||
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
|
||||
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
||||
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
|
||||
else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; }
|
||||
else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(argv[++i]); }
|
||||
|
||||
// server params
|
||||
else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
|
||||
else if ( arg == "--host") { sparams.hostname = argv[++i]; }
|
||||
@ -216,6 +224,24 @@ void check_ffmpeg_availibility() {
|
||||
}
|
||||
}
|
||||
|
||||
std::string generate_temp_filename(const std::string &prefix, const std::string &extension) {
|
||||
auto now = std::chrono::system_clock::now();
|
||||
auto now_time_t = std::chrono::system_clock::to_time_t(now);
|
||||
|
||||
static std::mt19937 rng{std::random_device{}()};
|
||||
std::uniform_int_distribution<long long> dist(0, 1e9);
|
||||
|
||||
std::stringstream ss;
|
||||
ss << prefix
|
||||
<< "-"
|
||||
<< std::put_time(std::localtime(&now_time_t), "%Y%m%d-%H%M%S")
|
||||
<< "-"
|
||||
<< dist(rng)
|
||||
<< extension;
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
bool convert_to_wav(const std::string & temp_filename, std::string & error_resp) {
|
||||
std::ostringstream cmd_stream;
|
||||
std::string converted_filename_temp = temp_filename + "_temp.wav";
|
||||
@ -472,6 +498,14 @@ void get_req_parameters(const Request & req, whisper_params & params)
|
||||
{
|
||||
params.temperature_inc = std::stof(req.get_file_value("temperature_inc").content);
|
||||
}
|
||||
if (req.has_file("suppress_non_speech"))
|
||||
{
|
||||
params.suppress_nst = parse_str_to_bool(req.get_file_value("suppress_non_speech").content);
|
||||
}
|
||||
if (req.has_file("suppress_nst"))
|
||||
{
|
||||
params.suppress_nst = parse_str_to_bool(req.get_file_value("suppress_nst").content);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -677,9 +711,7 @@ int main(int argc, char ** argv) {
|
||||
if (sparams.ffmpeg_converter) {
|
||||
// if file is not wav, convert to wav
|
||||
// write to temporary file
|
||||
//const std::string temp_filename_base = std::tmpnam(nullptr);
|
||||
const std::string temp_filename_base = "whisper-server-tmp"; // TODO: this is a hack, remove when the mutext is removed
|
||||
const std::string temp_filename = temp_filename_base + ".wav";
|
||||
const std::string temp_filename = generate_temp_filename("whisper-server", ".wav");
|
||||
std::ofstream temp_file{temp_filename, std::ios::binary};
|
||||
temp_file << audio_file.content;
|
||||
temp_file.close();
|
||||
@ -779,6 +811,7 @@ int main(int argc, char ** argv) {
|
||||
wparams.beam_search.beam_size = params.beam_size;
|
||||
|
||||
wparams.temperature = params.temperature;
|
||||
wparams.no_speech_thold = params.no_speech_thold;
|
||||
wparams.temperature_inc = params.temperature_inc;
|
||||
wparams.entropy_thold = params.entropy_thold;
|
||||
wparams.logprob_thold = params.logprob_thold;
|
||||
@ -786,6 +819,8 @@ int main(int argc, char ** argv) {
|
||||
wparams.no_timestamps = params.no_timestamps;
|
||||
wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format;
|
||||
|
||||
wparams.suppress_nst = params.suppress_nst;
|
||||
|
||||
whisper_print_user_data user_data = { ¶ms, &pcmf32s, 0 };
|
||||
|
||||
// this callback is called on each new segment
|
||||
@ -929,7 +964,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// TODO compression_ratio and no_speech_prob are not implemented yet
|
||||
// segment["compression_ratio"] = 0;
|
||||
// segment["no_speech_prob"] = 0;
|
||||
segment["no_speech_prob"] = whisper_full_get_segment_no_speech_prob(ctx, i);
|
||||
|
||||
jres["segments"].push_back(segment);
|
||||
}
|
||||
|
@ -12,7 +12,12 @@
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <chrono>
|
||||
|
||||
#if defined(_WIN32)
|
||||
#define NOMINMAX
|
||||
#include <windows.h>
|
||||
#endif
|
||||
|
||||
// command-line parameters
|
||||
struct whisper_params {
|
||||
@ -113,6 +118,10 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
#if defined(_WIN32)
|
||||
SetConsoleOutputCP(CP_UTF8);
|
||||
#endif
|
||||
|
||||
whisper_params params;
|
||||
|
||||
if (whisper_params_parse(argc, argv, params) == false) {
|
||||
@ -157,6 +166,7 @@ int main(int argc, char ** argv) {
|
||||
cparams.use_gpu = params.use_gpu;
|
||||
cparams.flash_attn = params.flash_attn;
|
||||
|
||||
fprintf(stderr, "whisper_init_from_file_with_params ...\n");
|
||||
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
||||
|
||||
std::vector<float> pcmf32 (n_samples_30s, 0.0f);
|
||||
@ -166,6 +176,8 @@ int main(int argc, char ** argv) {
|
||||
std::vector<whisper_token> prompt_tokens;
|
||||
|
||||
// print some info about the processing
|
||||
fprintf(stderr, "whisper_init_from_file_with_params ok\n");
|
||||
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
if (!whisper_is_multilingual(ctx)) {
|
||||
|
@ -1,10 +1,26 @@
|
||||
if (WHISPER_SDL2)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
set(TARGET whisper-talk-llama)
|
||||
add_executable(${TARGET} talk-llama.cpp
|
||||
llama.cpp
|
||||
llama-vocab.cpp
|
||||
llama-adapter.cpp
|
||||
llama-arch.cpp
|
||||
llama-batch.cpp
|
||||
llama-chat.cpp
|
||||
llama-context.cpp
|
||||
llama-cparams.cpp
|
||||
llama-grammar.cpp
|
||||
llama-hparams.cpp
|
||||
llama-impl.cpp
|
||||
llama-kv-cache.cpp
|
||||
llama-mmap.cpp
|
||||
llama-model-loader.cpp
|
||||
llama-model.cpp
|
||||
llama-quant.cpp
|
||||
llama-sampling.cpp
|
||||
llama-vocab.cpp
|
||||
unicode.cpp
|
||||
unicode-data.cpp)
|
||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
||||
|
347
examples/talk-llama/llama-adapter.cpp
Normal file
347
examples/talk-llama/llama-adapter.cpp
Normal file
@ -0,0 +1,347 @@
|
||||
#include "llama-adapter.h"
|
||||
|
||||
#include "llama-impl.h"
|
||||
#include "llama-mmap.h"
|
||||
#include "llama-model.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <cassert>
|
||||
#include <stdexcept>
|
||||
|
||||
// vec
|
||||
|
||||
struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
|
||||
if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return tensors[il];
|
||||
}
|
||||
|
||||
struct ggml_tensor * llama_adapter_cvec::apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const {
|
||||
ggml_tensor * layer_dir = tensor_for(il);
|
||||
if (layer_dir != nullptr) {
|
||||
cur = ggml_add(ctx, cur, layer_dir);
|
||||
}
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
bool llama_adapter_cvec::init(const llama_model & model) {
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
GGML_ASSERT(tensors.empty());
|
||||
GGML_ASSERT(ctxs.empty());
|
||||
GGML_ASSERT(bufs.empty());
|
||||
|
||||
// create a context for each buffer type
|
||||
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
||||
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
|
||||
auto it = ctx_map.find(buft);
|
||||
if (it == ctx_map.end()) {
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ hparams.n_layer*ggml_tensor_overhead(),
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
ggml_context * ctx = ggml_init(params);
|
||||
if (!ctx) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ctx_map[buft] = ctx;
|
||||
ctxs.emplace_back(ctx);
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
return it->second;
|
||||
};
|
||||
|
||||
// make tensors
|
||||
tensors.reserve(hparams.n_layer);
|
||||
tensors.push_back(nullptr); // there's never a tensor for layer 0
|
||||
for (size_t il = 1; il < hparams.n_layer; il++) {
|
||||
ggml_backend_buffer_type_t buft = model.select_buft(il);
|
||||
ggml_context * ctx = ctx_for_buft(buft);
|
||||
if (!ctx) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate context for control vector\n", __func__);
|
||||
return false;
|
||||
}
|
||||
ggml_tensor * tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd);
|
||||
tensors.push_back(tensor);
|
||||
}
|
||||
|
||||
// allocate tensors / buffers and zero
|
||||
bufs.reserve(ctx_map.size());
|
||||
for (auto it : ctx_map) {
|
||||
ggml_backend_buffer_type_t buft = it.first;
|
||||
ggml_context * ctx = it.second;
|
||||
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
|
||||
if (!buf) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate buffer for control vector\n", __func__);
|
||||
return false;
|
||||
}
|
||||
ggml_backend_buffer_clear(buf, 0);
|
||||
bufs.emplace_back(buf);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int32_t llama_adapter_cvec::apply(
|
||||
const llama_model & model,
|
||||
const float * data,
|
||||
size_t len,
|
||||
int32_t n_embd,
|
||||
int32_t il_start,
|
||||
int32_t il_end) {
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
if (data == nullptr) {
|
||||
// disable the current control vector (but leave allocated for later)
|
||||
layer_start = -1;
|
||||
layer_end = -1;
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (n_embd != (int) hparams.n_embd) {
|
||||
LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (tensors.empty()) {
|
||||
if (!init(model)) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
layer_start = il_start;
|
||||
layer_end = il_end;
|
||||
|
||||
for (size_t il = 1; il < hparams.n_layer; il++) {
|
||||
assert(tensors[il] != nullptr);
|
||||
|
||||
const size_t off = n_embd * (il - 1); // buffer doesn't have data for layer 0, since it's never present
|
||||
if (off + n_embd <= len) {
|
||||
ggml_backend_tensor_set(tensors[il], data + off, 0, n_embd * ggml_element_size(tensors[il]));
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// lora
|
||||
|
||||
llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor * w) {
|
||||
const std::string name(w->name);
|
||||
|
||||
const auto pos = ab_map.find(name);
|
||||
if (pos != ab_map.end()) {
|
||||
return &pos->second;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static void llama_adapter_lora_init_impl(struct llama_model & model, const char * path_lora, struct llama_adapter_lora & adapter) {
|
||||
LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);
|
||||
|
||||
ggml_context * ctx_init;
|
||||
struct gguf_init_params meta_gguf_params = {
|
||||
/* .no_alloc = */ true,
|
||||
/* .ctx = */ &ctx_init,
|
||||
};
|
||||
|
||||
gguf_context_ptr ctx_gguf { gguf_init_from_file(path_lora, meta_gguf_params) };
|
||||
if (!ctx_gguf) {
|
||||
throw std::runtime_error("failed to load lora adapter file from " + std::string(path_lora));
|
||||
}
|
||||
|
||||
ggml_context_ptr ctx { ctx_init };
|
||||
|
||||
// check metadata
|
||||
{
|
||||
auto get_kv_str = [&](const std::string & key) -> std::string {
|
||||
int id = gguf_find_key(ctx_gguf.get(), key.c_str());
|
||||
return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf.get(), id));
|
||||
};
|
||||
auto get_kv_f32 = [&](const std::string & key) -> float {
|
||||
int id = gguf_find_key(ctx_gguf.get(), key.c_str());
|
||||
return id < 0 ? 0.0f : gguf_get_val_f32(ctx_gguf.get(), id);
|
||||
};
|
||||
LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
|
||||
|
||||
auto general_type = get_kv_str(llm_kv(LLM_KV_GENERAL_TYPE));
|
||||
if (general_type != "adapter") {
|
||||
throw std::runtime_error("expect general.type to be 'adapter', but got: " + general_type);
|
||||
}
|
||||
|
||||
auto general_arch_str = get_kv_str(llm_kv(LLM_KV_GENERAL_ARCHITECTURE));
|
||||
auto general_arch = llm_arch_from_string(general_arch_str);
|
||||
if (general_arch != model.arch) {
|
||||
throw std::runtime_error("model arch and LoRA arch mismatch");
|
||||
}
|
||||
|
||||
auto adapter_type = get_kv_str(llm_kv(LLM_KV_ADAPTER_TYPE));
|
||||
if (adapter_type != "lora") {
|
||||
throw std::runtime_error("expect adapter.type to be 'lora', but got: " + adapter_type);
|
||||
}
|
||||
|
||||
adapter.alpha = get_kv_f32(llm_kv(LLM_KV_ADAPTER_LORA_ALPHA));
|
||||
}
|
||||
|
||||
int n_tensors = gguf_get_n_tensors(ctx_gguf.get());
|
||||
|
||||
// contexts for each buffer type
|
||||
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
||||
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
|
||||
auto it = ctx_map.find(buft);
|
||||
if (it == ctx_map.end()) {
|
||||
// add a new context
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ n_tensors*ggml_tensor_overhead(),
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
ggml_context * buft_ctx = ggml_init(params);
|
||||
if (!buft_ctx) {
|
||||
return nullptr;
|
||||
}
|
||||
ctx_map[buft] = buft_ctx;
|
||||
adapter.ctxs.emplace_back(buft_ctx);
|
||||
return buft_ctx;
|
||||
};
|
||||
return it->second;
|
||||
};
|
||||
|
||||
// bundle lora_a and lora_b into pairs
|
||||
std::map<std::string, llama_adapter_lora_weight> ab_map;
|
||||
auto str_endswith = [](const std::string & str, const std::string & suffix) {
|
||||
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
|
||||
};
|
||||
|
||||
for (ggml_tensor * cur = ggml_get_first_tensor(ctx.get()); cur; cur = ggml_get_next_tensor(ctx.get(), cur)) {
|
||||
std::string name(cur->name);
|
||||
if (str_endswith(name, ".lora_a")) {
|
||||
replace_all(name, ".lora_a", "");
|
||||
if (ab_map.find(name) == ab_map.end()) {
|
||||
ab_map[name] = llama_adapter_lora_weight(cur, nullptr);
|
||||
} else {
|
||||
ab_map[name].a = cur;
|
||||
}
|
||||
} else if (str_endswith(name, ".lora_b")) {
|
||||
replace_all(name, ".lora_b", "");
|
||||
if (ab_map.find(name) == ab_map.end()) {
|
||||
ab_map[name] = llama_adapter_lora_weight(nullptr, cur);
|
||||
} else {
|
||||
ab_map[name].b = cur;
|
||||
}
|
||||
} else if (str_endswith(name, "_norm.weight")) {
|
||||
// TODO: add support for norm vector
|
||||
// for now, we don't really care because most adapters still work fine without it
|
||||
continue;
|
||||
} else {
|
||||
throw std::runtime_error("LoRA tensor '" + name + "' has unexpected suffix");
|
||||
}
|
||||
}
|
||||
|
||||
// add tensors
|
||||
for (auto & it : ab_map) {
|
||||
const std::string & name = it.first;
|
||||
llama_adapter_lora_weight & w = it.second;
|
||||
bool is_token_embd = str_endswith(name, "token_embd.weight");
|
||||
|
||||
if (!w.a || !w.b) {
|
||||
throw std::runtime_error("LoRA tensor pair for '" + name + "' is missing one component");
|
||||
}
|
||||
|
||||
// device buft and device ctx
|
||||
const auto * model_tensor = model.get_tensor(name.c_str());
|
||||
if (!model_tensor) {
|
||||
throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)");
|
||||
}
|
||||
|
||||
struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
|
||||
// validate tensor shape
|
||||
if (is_token_embd) {
|
||||
// expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd()
|
||||
if (model_tensor->ne[0] != w.b->ne[1] || model_tensor->ne[1] != w.a->ne[1]) {
|
||||
throw std::runtime_error("tensor '" + name + "' has incorrect shape (hint: maybe wrong base model?)");
|
||||
}
|
||||
} else {
|
||||
if (model_tensor->ne[0] != w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) {
|
||||
throw std::runtime_error("tensor '" + name + "' has incorrect shape (hint: maybe wrong base model?)");
|
||||
}
|
||||
if (w.a->ne[1] != w.b->ne[0]) {
|
||||
throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)");
|
||||
}
|
||||
}
|
||||
|
||||
// save tensor to adapter
|
||||
struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
|
||||
struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
|
||||
ggml_set_name(tensor_a, w.a->name);
|
||||
ggml_set_name(tensor_b, w.b->name);
|
||||
adapter.ab_map[name] = llama_adapter_lora_weight(tensor_a, tensor_b);
|
||||
}
|
||||
|
||||
// allocate tensors / buffers and zero
|
||||
{
|
||||
adapter.ctxs.reserve(ctx_map.size());
|
||||
adapter.bufs.reserve(ctx_map.size());
|
||||
for (auto & it : ctx_map) {
|
||||
ggml_backend_buffer_type_t buft = it.first;
|
||||
ggml_context * ctx_dev = it.second;
|
||||
ggml_backend_buffer_ptr buf { ggml_backend_alloc_ctx_tensors_from_buft(ctx_dev, buft) };
|
||||
if (!buf) {
|
||||
throw std::runtime_error("failed to allocate buffer for lora adapter\n");
|
||||
}
|
||||
LLAMA_LOG_INFO("%s: %10s LoRA buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get())/1024.0/1024.0);
|
||||
adapter.bufs.emplace_back(std::move(buf));
|
||||
}
|
||||
}
|
||||
|
||||
// set tensor data
|
||||
{
|
||||
llama_file gguf_file(path_lora, "rb");
|
||||
std::vector<uint8_t> read_buf;
|
||||
auto set_tensor = [&](struct ggml_tensor * orig, struct ggml_tensor * dev) {
|
||||
size_t offs = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), gguf_find_tensor(ctx_gguf.get(), orig->name));
|
||||
size_t size = ggml_nbytes(orig);
|
||||
read_buf.resize(size);
|
||||
gguf_file.seek(offs, SEEK_SET);
|
||||
gguf_file.read_raw(read_buf.data(), size);
|
||||
ggml_backend_tensor_set(dev, read_buf.data(), 0, size);
|
||||
};
|
||||
for (auto & it : adapter.ab_map) {
|
||||
auto orig = ab_map[it.first];
|
||||
auto dev = it.second;
|
||||
set_tensor(orig.a, dev.a);
|
||||
set_tensor(orig.b, dev.b);
|
||||
}
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
|
||||
}
|
||||
|
||||
struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model, const char * path_lora) {
|
||||
struct llama_adapter_lora * adapter = new llama_adapter_lora();
|
||||
|
||||
try {
|
||||
llama_adapter_lora_init_impl(*model, path_lora, *adapter);
|
||||
return adapter;
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
|
||||
|
||||
delete adapter;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void llama_adapter_lora_free(struct llama_adapter_lora * adapter) {
|
||||
delete adapter;
|
||||
}
|
74
examples/talk-llama/llama-adapter.h
Normal file
74
examples/talk-llama/llama-adapter.h
Normal file
@ -0,0 +1,74 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
#include "ggml-cpp.h"
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
// TODO: pimpl
|
||||
|
||||
//
|
||||
// llama_adapter_cvec
|
||||
//
|
||||
|
||||
struct llama_adapter_cvec {
|
||||
struct ggml_tensor * tensor_for(int il) const;
|
||||
|
||||
struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const;
|
||||
|
||||
int32_t apply(
|
||||
const llama_model & model,
|
||||
const float * data,
|
||||
size_t len,
|
||||
int32_t n_embd,
|
||||
int32_t il_start,
|
||||
int32_t il_end);
|
||||
|
||||
private:
|
||||
bool init(const llama_model & model);
|
||||
|
||||
int32_t layer_start = -1;
|
||||
int32_t layer_end = -1;
|
||||
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
|
||||
std::vector<struct ggml_tensor *> tensors; // per layer
|
||||
};
|
||||
|
||||
//
|
||||
// llama_adapter_lora
|
||||
//
|
||||
|
||||
struct llama_adapter_lora_weight {
|
||||
struct ggml_tensor * a = nullptr;
|
||||
struct ggml_tensor * b = nullptr;
|
||||
|
||||
// get actual scale based on rank and alpha
|
||||
float get_scale(float alpha, float adapter_scale) const {
|
||||
const float rank = (float) b->ne[0];
|
||||
const float scale = alpha ? adapter_scale * alpha / rank : adapter_scale;
|
||||
return scale;
|
||||
}
|
||||
|
||||
llama_adapter_lora_weight() = default;
|
||||
llama_adapter_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {}
|
||||
};
|
||||
|
||||
struct llama_adapter_lora {
|
||||
// map tensor name to lora_a_b
|
||||
std::unordered_map<std::string, struct llama_adapter_lora_weight> ab_map;
|
||||
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
|
||||
float alpha;
|
||||
|
||||
llama_adapter_lora() = default;
|
||||
~llama_adapter_lora() = default;
|
||||
|
||||
llama_adapter_lora_weight * get_weight(struct ggml_tensor * w);
|
||||
};
|
1492
examples/talk-llama/llama-arch.cpp
Normal file
1492
examples/talk-llama/llama-arch.cpp
Normal file
File diff suppressed because it is too large
Load Diff
402
examples/talk-llama/llama-arch.h
Normal file
402
examples/talk-llama/llama-arch.h
Normal file
@ -0,0 +1,402 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml.h" // ggml_op
|
||||
|
||||
#include <string>
|
||||
|
||||
//
|
||||
// gguf constants (sync with gguf.py)
|
||||
//
|
||||
|
||||
enum llm_arch {
|
||||
LLM_ARCH_LLAMA,
|
||||
LLM_ARCH_DECI,
|
||||
LLM_ARCH_FALCON,
|
||||
LLM_ARCH_BAICHUAN,
|
||||
LLM_ARCH_GROK,
|
||||
LLM_ARCH_GPT2,
|
||||
LLM_ARCH_GPTJ,
|
||||
LLM_ARCH_GPTNEOX,
|
||||
LLM_ARCH_MPT,
|
||||
LLM_ARCH_STARCODER,
|
||||
LLM_ARCH_REFACT,
|
||||
LLM_ARCH_BERT,
|
||||
LLM_ARCH_NOMIC_BERT,
|
||||
LLM_ARCH_JINA_BERT_V2,
|
||||
LLM_ARCH_BLOOM,
|
||||
LLM_ARCH_STABLELM,
|
||||
LLM_ARCH_QWEN,
|
||||
LLM_ARCH_QWEN2,
|
||||
LLM_ARCH_QWEN2MOE,
|
||||
LLM_ARCH_QWEN2VL,
|
||||
LLM_ARCH_PHI2,
|
||||
LLM_ARCH_PHI3,
|
||||
LLM_ARCH_PHIMOE,
|
||||
LLM_ARCH_PLAMO,
|
||||
LLM_ARCH_CODESHELL,
|
||||
LLM_ARCH_ORION,
|
||||
LLM_ARCH_INTERNLM2,
|
||||
LLM_ARCH_MINICPM,
|
||||
LLM_ARCH_MINICPM3,
|
||||
LLM_ARCH_GEMMA,
|
||||
LLM_ARCH_GEMMA2,
|
||||
LLM_ARCH_STARCODER2,
|
||||
LLM_ARCH_MAMBA,
|
||||
LLM_ARCH_XVERSE,
|
||||
LLM_ARCH_COMMAND_R,
|
||||
LLM_ARCH_COHERE2,
|
||||
LLM_ARCH_DBRX,
|
||||
LLM_ARCH_OLMO,
|
||||
LLM_ARCH_OLMO2,
|
||||
LLM_ARCH_OLMOE,
|
||||
LLM_ARCH_OPENELM,
|
||||
LLM_ARCH_ARCTIC,
|
||||
LLM_ARCH_DEEPSEEK,
|
||||
LLM_ARCH_DEEPSEEK2,
|
||||
LLM_ARCH_CHATGLM,
|
||||
LLM_ARCH_BITNET,
|
||||
LLM_ARCH_T5,
|
||||
LLM_ARCH_T5ENCODER,
|
||||
LLM_ARCH_JAIS,
|
||||
LLM_ARCH_NEMOTRON,
|
||||
LLM_ARCH_EXAONE,
|
||||
LLM_ARCH_RWKV6,
|
||||
LLM_ARCH_RWKV6QWEN2,
|
||||
LLM_ARCH_GRANITE,
|
||||
LLM_ARCH_GRANITE_MOE,
|
||||
LLM_ARCH_CHAMELEON,
|
||||
LLM_ARCH_WAVTOKENIZER_DEC,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
enum llm_kv {
|
||||
LLM_KV_GENERAL_TYPE,
|
||||
LLM_KV_GENERAL_ARCHITECTURE,
|
||||
LLM_KV_GENERAL_QUANTIZATION_VERSION,
|
||||
LLM_KV_GENERAL_ALIGNMENT,
|
||||
LLM_KV_GENERAL_NAME,
|
||||
LLM_KV_GENERAL_AUTHOR,
|
||||
LLM_KV_GENERAL_VERSION,
|
||||
LLM_KV_GENERAL_URL,
|
||||
LLM_KV_GENERAL_DESCRIPTION,
|
||||
LLM_KV_GENERAL_LICENSE,
|
||||
LLM_KV_GENERAL_SOURCE_URL,
|
||||
LLM_KV_GENERAL_SOURCE_HF_REPO,
|
||||
|
||||
LLM_KV_VOCAB_SIZE,
|
||||
LLM_KV_CONTEXT_LENGTH,
|
||||
LLM_KV_EMBEDDING_LENGTH,
|
||||
LLM_KV_FEATURES_LENGTH,
|
||||
LLM_KV_BLOCK_COUNT,
|
||||
LLM_KV_LEADING_DENSE_BLOCK_COUNT,
|
||||
LLM_KV_FEED_FORWARD_LENGTH,
|
||||
LLM_KV_EXPERT_FEED_FORWARD_LENGTH,
|
||||
LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH,
|
||||
LLM_KV_USE_PARALLEL_RESIDUAL,
|
||||
LLM_KV_TENSOR_DATA_LAYOUT,
|
||||
LLM_KV_EXPERT_COUNT,
|
||||
LLM_KV_EXPERT_USED_COUNT,
|
||||
LLM_KV_EXPERT_SHARED_COUNT,
|
||||
LLM_KV_EXPERT_WEIGHTS_SCALE,
|
||||
LLM_KV_EXPERT_WEIGHTS_NORM,
|
||||
LLM_KV_EXPERT_GATING_FUNC,
|
||||
LLM_KV_POOLING_TYPE,
|
||||
LLM_KV_LOGIT_SCALE,
|
||||
LLM_KV_DECODER_START_TOKEN_ID,
|
||||
LLM_KV_ATTN_LOGIT_SOFTCAPPING,
|
||||
LLM_KV_FINAL_LOGIT_SOFTCAPPING,
|
||||
LLM_KV_SWIN_NORM,
|
||||
LLM_KV_RESCALE_EVERY_N_LAYERS,
|
||||
LLM_KV_TIME_MIX_EXTRA_DIM,
|
||||
LLM_KV_TIME_DECAY_EXTRA_DIM,
|
||||
LLM_KV_RESIDUAL_SCALE,
|
||||
LLM_KV_EMBEDDING_SCALE,
|
||||
LLM_KV_TOKEN_SHIFT_COUNT,
|
||||
|
||||
LLM_KV_ATTENTION_HEAD_COUNT,
|
||||
LLM_KV_ATTENTION_HEAD_COUNT_KV,
|
||||
LLM_KV_ATTENTION_MAX_ALIBI_BIAS,
|
||||
LLM_KV_ATTENTION_CLAMP_KQV,
|
||||
LLM_KV_ATTENTION_KEY_LENGTH,
|
||||
LLM_KV_ATTENTION_VALUE_LENGTH,
|
||||
LLM_KV_ATTENTION_LAYERNORM_EPS,
|
||||
LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,
|
||||
LLM_KV_ATTENTION_GROUPNORM_EPS,
|
||||
LLM_KV_ATTENTION_GROUPNORM_GROUPS,
|
||||
LLM_KV_ATTENTION_CAUSAL,
|
||||
LLM_KV_ATTENTION_Q_LORA_RANK,
|
||||
LLM_KV_ATTENTION_KV_LORA_RANK,
|
||||
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
|
||||
LLM_KV_ATTENTION_SLIDING_WINDOW,
|
||||
LLM_KV_ATTENTION_SCALE,
|
||||
|
||||
LLM_KV_ROPE_DIMENSION_COUNT,
|
||||
LLM_KV_ROPE_DIMENSION_SECTIONS,
|
||||
LLM_KV_ROPE_FREQ_BASE,
|
||||
LLM_KV_ROPE_SCALE_LINEAR,
|
||||
LLM_KV_ROPE_SCALING_TYPE,
|
||||
LLM_KV_ROPE_SCALING_FACTOR,
|
||||
LLM_KV_ROPE_SCALING_ATTN_FACTOR,
|
||||
LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
|
||||
LLM_KV_ROPE_SCALING_FINETUNED,
|
||||
LLM_KV_ROPE_SCALING_YARN_LOG_MUL,
|
||||
|
||||
LLM_KV_SPLIT_NO,
|
||||
LLM_KV_SPLIT_COUNT,
|
||||
LLM_KV_SPLIT_TENSORS_COUNT,
|
||||
|
||||
LLM_KV_SSM_INNER_SIZE,
|
||||
LLM_KV_SSM_CONV_KERNEL,
|
||||
LLM_KV_SSM_STATE_SIZE,
|
||||
LLM_KV_SSM_TIME_STEP_RANK,
|
||||
LLM_KV_SSM_DT_B_C_RMS,
|
||||
|
||||
LLM_KV_WKV_HEAD_SIZE,
|
||||
|
||||
LLM_KV_TOKENIZER_MODEL,
|
||||
LLM_KV_TOKENIZER_PRE,
|
||||
LLM_KV_TOKENIZER_LIST,
|
||||
LLM_KV_TOKENIZER_TOKEN_TYPE,
|
||||
LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT,
|
||||
LLM_KV_TOKENIZER_SCORES,
|
||||
LLM_KV_TOKENIZER_MERGES,
|
||||
LLM_KV_TOKENIZER_BOS_ID,
|
||||
LLM_KV_TOKENIZER_EOS_ID,
|
||||
LLM_KV_TOKENIZER_EOT_ID,
|
||||
LLM_KV_TOKENIZER_EOM_ID,
|
||||
LLM_KV_TOKENIZER_UNK_ID,
|
||||
LLM_KV_TOKENIZER_SEP_ID,
|
||||
LLM_KV_TOKENIZER_PAD_ID,
|
||||
LLM_KV_TOKENIZER_CLS_ID,
|
||||
LLM_KV_TOKENIZER_MASK_ID,
|
||||
LLM_KV_TOKENIZER_ADD_BOS,
|
||||
LLM_KV_TOKENIZER_ADD_EOS,
|
||||
LLM_KV_TOKENIZER_ADD_PREFIX,
|
||||
LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,
|
||||
LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,
|
||||
LLM_KV_TOKENIZER_HF_JSON,
|
||||
LLM_KV_TOKENIZER_RWKV,
|
||||
LLM_KV_TOKENIZER_CHAT_TEMPLATE,
|
||||
LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
|
||||
LLM_KV_TOKENIZER_FIM_PRE_ID,
|
||||
LLM_KV_TOKENIZER_FIM_SUF_ID,
|
||||
LLM_KV_TOKENIZER_FIM_MID_ID,
|
||||
LLM_KV_TOKENIZER_FIM_PAD_ID,
|
||||
LLM_KV_TOKENIZER_FIM_REP_ID,
|
||||
LLM_KV_TOKENIZER_FIM_SEP_ID,
|
||||
|
||||
LLM_KV_ADAPTER_TYPE,
|
||||
LLM_KV_ADAPTER_LORA_ALPHA,
|
||||
|
||||
LLM_KV_POSNET_EMBEDDING_LENGTH,
|
||||
LLM_KV_POSNET_BLOCK_COUNT,
|
||||
|
||||
LLM_KV_CONVNEXT_EMBEDDING_LENGTH,
|
||||
LLM_KV_CONVNEXT_BLOCK_COUNT,
|
||||
|
||||
// deprecated:
|
||||
LLM_KV_TOKENIZER_PREFIX_ID,
|
||||
LLM_KV_TOKENIZER_SUFFIX_ID,
|
||||
LLM_KV_TOKENIZER_MIDDLE_ID,
|
||||
};
|
||||
|
||||
enum llm_tensor {
|
||||
LLM_TENSOR_TOKEN_EMBD,
|
||||
LLM_TENSOR_TOKEN_EMBD_NORM,
|
||||
LLM_TENSOR_TOKEN_TYPES,
|
||||
LLM_TENSOR_POS_EMBD,
|
||||
LLM_TENSOR_OUTPUT,
|
||||
LLM_TENSOR_OUTPUT_NORM,
|
||||
LLM_TENSOR_ROPE_FREQS,
|
||||
LLM_TENSOR_ROPE_FACTORS_LONG,
|
||||
LLM_TENSOR_ROPE_FACTORS_SHORT,
|
||||
LLM_TENSOR_ATTN_Q,
|
||||
LLM_TENSOR_ATTN_K,
|
||||
LLM_TENSOR_ATTN_V,
|
||||
LLM_TENSOR_ATTN_QKV,
|
||||
LLM_TENSOR_ATTN_OUT,
|
||||
LLM_TENSOR_ATTN_NORM,
|
||||
LLM_TENSOR_ATTN_NORM_2,
|
||||
LLM_TENSOR_ATTN_OUT_NORM,
|
||||
LLM_TENSOR_ATTN_POST_NORM,
|
||||
LLM_TENSOR_ATTN_ROT_EMBD,
|
||||
LLM_TENSOR_FFN_GATE_INP,
|
||||
LLM_TENSOR_FFN_GATE_INP_SHEXP,
|
||||
LLM_TENSOR_FFN_NORM,
|
||||
LLM_TENSOR_FFN_POST_NORM,
|
||||
LLM_TENSOR_FFN_GATE,
|
||||
LLM_TENSOR_FFN_DOWN,
|
||||
LLM_TENSOR_FFN_UP,
|
||||
LLM_TENSOR_FFN_ACT,
|
||||
LLM_TENSOR_FFN_DOWN_EXP, // split experts for backward compatibility
|
||||
LLM_TENSOR_FFN_GATE_EXP,
|
||||
LLM_TENSOR_FFN_UP_EXP,
|
||||
LLM_TENSOR_FFN_NORM_EXPS,
|
||||
LLM_TENSOR_FFN_DOWN_EXPS, // merged experts
|
||||
LLM_TENSOR_FFN_GATE_EXPS,
|
||||
LLM_TENSOR_FFN_UP_EXPS,
|
||||
LLM_TENSOR_FFN_DOWN_SHEXP,
|
||||
LLM_TENSOR_FFN_GATE_SHEXP,
|
||||
LLM_TENSOR_FFN_UP_SHEXP,
|
||||
LLM_TENSOR_FFN_EXP_PROBS_B,
|
||||
LLM_TENSOR_ATTN_Q_NORM,
|
||||
LLM_TENSOR_ATTN_K_NORM,
|
||||
LLM_TENSOR_LAYER_OUT_NORM,
|
||||
LLM_TENSOR_SSM_IN,
|
||||
LLM_TENSOR_SSM_CONV1D,
|
||||
LLM_TENSOR_SSM_X,
|
||||
LLM_TENSOR_SSM_DT,
|
||||
LLM_TENSOR_SSM_A,
|
||||
LLM_TENSOR_SSM_D,
|
||||
LLM_TENSOR_SSM_OUT,
|
||||
LLM_TENSOR_TIME_MIX_W1,
|
||||
LLM_TENSOR_TIME_MIX_W2,
|
||||
LLM_TENSOR_TIME_MIX_LERP_X,
|
||||
LLM_TENSOR_TIME_MIX_LERP_W,
|
||||
LLM_TENSOR_TIME_MIX_LERP_K,
|
||||
LLM_TENSOR_TIME_MIX_LERP_V,
|
||||
LLM_TENSOR_TIME_MIX_LERP_R,
|
||||
LLM_TENSOR_TIME_MIX_LERP_G,
|
||||
LLM_TENSOR_TIME_MIX_LERP_FUSED,
|
||||
LLM_TENSOR_TIME_MIX_FIRST,
|
||||
LLM_TENSOR_TIME_MIX_DECAY,
|
||||
LLM_TENSOR_TIME_MIX_DECAY_W1,
|
||||
LLM_TENSOR_TIME_MIX_DECAY_W2,
|
||||
LLM_TENSOR_TIME_MIX_KEY,
|
||||
LLM_TENSOR_TIME_MIX_VALUE,
|
||||
LLM_TENSOR_TIME_MIX_RECEPTANCE,
|
||||
LLM_TENSOR_TIME_MIX_GATE,
|
||||
LLM_TENSOR_TIME_MIX_LN,
|
||||
LLM_TENSOR_TIME_MIX_OUTPUT,
|
||||
LLM_TENSOR_CHANNEL_MIX_LERP_K,
|
||||
LLM_TENSOR_CHANNEL_MIX_LERP_R,
|
||||
LLM_TENSOR_CHANNEL_MIX_KEY,
|
||||
LLM_TENSOR_CHANNEL_MIX_RECEPTANCE,
|
||||
LLM_TENSOR_CHANNEL_MIX_VALUE,
|
||||
LLM_TENSOR_ATTN_Q_A,
|
||||
LLM_TENSOR_ATTN_Q_B,
|
||||
LLM_TENSOR_ATTN_KV_A_MQA,
|
||||
LLM_TENSOR_ATTN_KV_B,
|
||||
LLM_TENSOR_ATTN_Q_A_NORM,
|
||||
LLM_TENSOR_ATTN_KV_A_NORM,
|
||||
LLM_TENSOR_ATTN_SUB_NORM,
|
||||
LLM_TENSOR_FFN_SUB_NORM,
|
||||
LLM_TENSOR_DEC_ATTN_NORM,
|
||||
LLM_TENSOR_DEC_ATTN_Q,
|
||||
LLM_TENSOR_DEC_ATTN_K,
|
||||
LLM_TENSOR_DEC_ATTN_V,
|
||||
LLM_TENSOR_DEC_ATTN_OUT,
|
||||
LLM_TENSOR_DEC_ATTN_REL_B,
|
||||
LLM_TENSOR_DEC_CROSS_ATTN_NORM,
|
||||
LLM_TENSOR_DEC_CROSS_ATTN_Q,
|
||||
LLM_TENSOR_DEC_CROSS_ATTN_K,
|
||||
LLM_TENSOR_DEC_CROSS_ATTN_V,
|
||||
LLM_TENSOR_DEC_CROSS_ATTN_OUT,
|
||||
LLM_TENSOR_DEC_CROSS_ATTN_REL_B,
|
||||
LLM_TENSOR_DEC_FFN_NORM,
|
||||
LLM_TENSOR_DEC_FFN_GATE,
|
||||
LLM_TENSOR_DEC_FFN_DOWN,
|
||||
LLM_TENSOR_DEC_FFN_UP,
|
||||
LLM_TENSOR_DEC_OUTPUT_NORM,
|
||||
LLM_TENSOR_ENC_ATTN_NORM,
|
||||
LLM_TENSOR_ENC_ATTN_Q,
|
||||
LLM_TENSOR_ENC_ATTN_K,
|
||||
LLM_TENSOR_ENC_ATTN_V,
|
||||
LLM_TENSOR_ENC_ATTN_OUT,
|
||||
LLM_TENSOR_ENC_ATTN_REL_B,
|
||||
LLM_TENSOR_ENC_FFN_NORM,
|
||||
LLM_TENSOR_ENC_FFN_GATE,
|
||||
LLM_TENSOR_ENC_FFN_DOWN,
|
||||
LLM_TENSOR_ENC_FFN_UP,
|
||||
LLM_TENSOR_ENC_OUTPUT_NORM,
|
||||
LLM_TENSOR_CLS,
|
||||
LLM_TENSOR_CLS_OUT,
|
||||
LLM_TENSOR_CONV1D,
|
||||
LLM_TENSOR_CONVNEXT_DW,
|
||||
LLM_TENSOR_CONVNEXT_NORM,
|
||||
LLM_TENSOR_CONVNEXT_PW1,
|
||||
LLM_TENSOR_CONVNEXT_PW2,
|
||||
LLM_TENSOR_CONVNEXT_GAMMA,
|
||||
LLM_TENSOR_POS_NET_CONV1,
|
||||
LLM_TENSOR_POS_NET_CONV2,
|
||||
LLM_TENSOR_POS_NET_NORM,
|
||||
LLM_TENSOR_POS_NET_NORM1,
|
||||
LLM_TENSOR_POS_NET_NORM2,
|
||||
LLM_TENSOR_POS_NET_ATTN_NORM,
|
||||
LLM_TENSOR_POS_NET_ATTN_Q,
|
||||
LLM_TENSOR_POS_NET_ATTN_K,
|
||||
LLM_TENSOR_POS_NET_ATTN_V,
|
||||
LLM_TENSOR_POS_NET_ATTN_OUT,
|
||||
};
|
||||
|
||||
enum llm_tensor_layer {
|
||||
LLM_TENSOR_LAYER_INPUT,
|
||||
LLM_TENSOR_LAYER_REPEATING,
|
||||
LLM_TENSOR_LAYER_OUTPUT,
|
||||
};
|
||||
|
||||
struct LLM_KV {
|
||||
LLM_KV(llm_arch arch, const char * suffix = nullptr);
|
||||
|
||||
llm_arch arch;
|
||||
const char * suffix;
|
||||
|
||||
std::string operator()(llm_kv kv) const;
|
||||
};
|
||||
|
||||
// helper to handle gguf constants
|
||||
// usage:
|
||||
//
|
||||
// const auto tn = LLM_TN(LLM_ARCH_LLAMA);
|
||||
//
|
||||
// std::string name = tn(LLM_TENSOR_OUTPUT); -> "output"
|
||||
// std::string name = tn(LLM_TENSOR_TOKEN_EMBD, "bias"); -> "token_embd.bias"
|
||||
// std::string name = tn(LLM_TENSOR_ATTN_NORM, "weight", 3); -> "blk.3.attn_norm.weight"
|
||||
//
|
||||
struct LLM_TN_IMPL {
|
||||
const llm_arch arch;
|
||||
const llm_tensor tensor;
|
||||
const char * const suffix;
|
||||
const int bid;
|
||||
const int xid;
|
||||
|
||||
std::string str() const;
|
||||
|
||||
operator std::string() const {
|
||||
return str();
|
||||
}
|
||||
|
||||
friend bool operator==(const std::string & str, const LLM_TN_IMPL & tn) {
|
||||
return str == tn.str();
|
||||
}
|
||||
|
||||
friend bool operator!=(const std::string & str, const LLM_TN_IMPL & tn) {
|
||||
return str != tn.str();
|
||||
}
|
||||
};
|
||||
|
||||
struct LLM_TN {
|
||||
LLM_TN(llm_arch arch) : arch(arch) {}
|
||||
|
||||
llm_arch arch;
|
||||
|
||||
LLM_TN_IMPL operator()(llm_tensor tensor, const char * suffix, int bid = -1, int xid = -1) const {
|
||||
return { arch, tensor, suffix, bid, xid };
|
||||
}
|
||||
|
||||
LLM_TN_IMPL operator()(llm_tensor tensor, int bid = -1, int xid = -1) const {
|
||||
return { arch, tensor, nullptr, bid, xid };
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
struct llm_tensor_info {
|
||||
llm_tensor_layer layer;
|
||||
ggml_op op;
|
||||
};
|
||||
|
||||
const char * llm_arch_name(llm_arch arch);
|
||||
|
||||
llm_arch llm_arch_from_string(const std::string & name);
|
||||
|
||||
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
|
368
examples/talk-llama/llama-batch.cpp
Normal file
368
examples/talk-llama/llama-batch.cpp
Normal file
@ -0,0 +1,368 @@
|
||||
#include "llama-batch.h"
|
||||
|
||||
#include <cstring>
|
||||
#include <algorithm>
|
||||
|
||||
llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
|
||||
// clear empty sequences
|
||||
// the previous ubatch is assumed to be gone,
|
||||
// so nothing should refer to values in these sequences anymore.
|
||||
for (size_t i = seq.size(); i-- > 0;) {
|
||||
if (seq[i].length == 0) {
|
||||
seq.pop_back();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
ubatch_token.resize(!has_embd ? n_ubatch : 0);
|
||||
ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
|
||||
ubatch_pos.resize(n_ubatch);
|
||||
ubatch_n_seq_id.resize(n_ubatch);
|
||||
ubatch_seq_id.resize(n_ubatch);
|
||||
ubatch_output.resize(n_ubatch);
|
||||
llama_ubatch ubatch = {
|
||||
/*equal_seqs =*/ true,
|
||||
/*n_tokens =*/ 0,
|
||||
/*n_seq_tokens =*/ 0,
|
||||
/*n_seqs =*/ 0,
|
||||
/*token =*/ !has_embd ? ubatch_token.data() : nullptr,
|
||||
/*embd =*/ has_embd ? ubatch_embd.data() : nullptr,
|
||||
/*pos =*/ ubatch_pos.data(),
|
||||
/*n_seq_id =*/ ubatch_n_seq_id.data(),
|
||||
/*seq_id =*/ ubatch_seq_id.data(),
|
||||
/*output =*/ ubatch_output.data(),
|
||||
};
|
||||
return ubatch;
|
||||
}
|
||||
|
||||
void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
|
||||
GGML_ASSERT(batch != nullptr);
|
||||
GGML_ASSERT(length <= seq.length);
|
||||
// Can only add sequences of equal lengths to a batch,
|
||||
// otherwise it isn't clear to which sequence a token belongs
|
||||
GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
|
||||
GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
|
||||
// NOTE: loops are separated for cache-friendliness
|
||||
if (batch->token) {
|
||||
if (ubatch.equal_seqs) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
|
||||
}
|
||||
} else {
|
||||
// simple split
|
||||
ubatch.token = batch->token + seq.offset;
|
||||
}
|
||||
} else {
|
||||
ubatch.token = nullptr;
|
||||
}
|
||||
if (batch->embd) {
|
||||
if (ubatch.equal_seqs) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
memcpy(
|
||||
ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
|
||||
batch->embd + (n_embd * ids[seq.offset + i]),
|
||||
n_embd * sizeof(float)
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// simple split
|
||||
ubatch.embd = batch->embd + (n_embd * seq.offset);
|
||||
}
|
||||
} else {
|
||||
ubatch.embd = nullptr;
|
||||
}
|
||||
if (ubatch.equal_seqs) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
|
||||
}
|
||||
} else {
|
||||
// simple split
|
||||
ubatch.pos = batch->pos + seq.offset;
|
||||
}
|
||||
if (ubatch.equal_seqs) {
|
||||
ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
|
||||
if (seq.seq_id) {
|
||||
ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
|
||||
}
|
||||
} else {
|
||||
// simple split
|
||||
if (batch->n_seq_id) {
|
||||
ubatch.n_seq_id = batch->n_seq_id + seq.offset;
|
||||
} else {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
|
||||
}
|
||||
}
|
||||
if (batch->seq_id) {
|
||||
ubatch.seq_id = batch->seq_id + seq.offset;
|
||||
}
|
||||
}
|
||||
if (logits_all) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
ubatch.output[ubatch.n_tokens + i] = 1;
|
||||
out_ids.push_back(ids[seq.offset + i]);
|
||||
}
|
||||
} else if (batch->logits) {
|
||||
if (ubatch.equal_seqs) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
size_t id = ids[seq.offset + i];
|
||||
int8_t is_output = batch->logits[id];
|
||||
ubatch.output[ubatch.n_tokens + i] = is_output;
|
||||
if (is_output) { out_ids.push_back(id); }
|
||||
}
|
||||
} else {
|
||||
// simple split
|
||||
ubatch.output = batch->logits + seq.offset;
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// only get last output
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
size_t id = ids[seq.offset + i];
|
||||
int8_t is_last = id == ids.size() - 1;
|
||||
ubatch.output[ubatch.n_tokens + i] = is_last;
|
||||
if (is_last) { out_ids.push_back(id); }
|
||||
}
|
||||
}
|
||||
if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
|
||||
ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
|
||||
}
|
||||
ubatch.n_tokens += length;
|
||||
ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
|
||||
seq.offset += length;
|
||||
seq.length -= length;
|
||||
n_tokens -= length;
|
||||
GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
|
||||
}
|
||||
|
||||
llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
|
||||
n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
|
||||
llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
|
||||
ubatch.equal_seqs = false;
|
||||
if (!seq.empty()) {
|
||||
llama_sbatch_seq & s = seq[0];
|
||||
size_t length = s.length < n_ubatch ? s.length : n_ubatch;
|
||||
GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
|
||||
add_seq_to_ubatch(ubatch, s, length);
|
||||
}
|
||||
return ubatch;
|
||||
}
|
||||
|
||||
llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
|
||||
n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
|
||||
llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
|
||||
if (!seq.empty()) {
|
||||
size_t length = 0;
|
||||
size_t n_tokens_in_ubatch = 0;
|
||||
GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
|
||||
// smallest first, because it's easier to split this way;
|
||||
// starting from the end to pop in constant time.
|
||||
for (size_t i = seq.size(); i-- > 0;) {
|
||||
llama_sbatch_seq & s = seq[i];
|
||||
GGML_ASSERT(s.length > 0);
|
||||
if (length == 0) {
|
||||
length = s.length < n_ubatch ? s.length : n_ubatch;
|
||||
}
|
||||
add_seq_to_ubatch(ubatch, s, length);
|
||||
n_tokens_in_ubatch += length;
|
||||
// shared prompts can't be mixed with any of their sequences,
|
||||
// so it's safer to compute them in their own ubatch
|
||||
if (s.n_seq_id > 1) { break; }
|
||||
// stop when there isn't enough space for another sequence
|
||||
if (length + n_tokens_in_ubatch > n_ubatch) { break; }
|
||||
}
|
||||
}
|
||||
return ubatch;
|
||||
}
|
||||
|
||||
llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
|
||||
n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
|
||||
llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
|
||||
if (!seq.empty()) {
|
||||
llama_sbatch_seq & s = seq[seq.size() - 1];
|
||||
size_t length = s.length < n_ubatch ? s.length : n_ubatch;
|
||||
GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
|
||||
add_seq_to_ubatch(ubatch, s, length);
|
||||
}
|
||||
return ubatch;
|
||||
}
|
||||
|
||||
void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
|
||||
GGML_ASSERT(batch.n_tokens >= 0);
|
||||
this->batch = &batch;
|
||||
this->n_embd = n_embd;
|
||||
this->logits_all = logits_all;
|
||||
|
||||
n_tokens = batch.n_tokens;
|
||||
ids.resize(n_tokens);
|
||||
out_ids.clear();
|
||||
// TODO: reserve out_ids and seq
|
||||
|
||||
for (size_t i = 0; i < n_tokens; ++i) {
|
||||
ids[i] = i;
|
||||
}
|
||||
if (simple_split) {
|
||||
seq.resize(1);
|
||||
llama_sbatch_seq & s = seq[0];
|
||||
s.n_seq_id = 0;
|
||||
s.seq_id = nullptr;
|
||||
s.offset = 0;
|
||||
s.length = n_tokens;
|
||||
return;
|
||||
}
|
||||
std::sort(ids.begin(), ids.end(),
|
||||
[&batch](size_t a, size_t b) {
|
||||
int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
|
||||
int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
|
||||
// sort by seq_id, then by pos
|
||||
if (n_seq_a == n_seq_b) {
|
||||
if (batch.seq_id) {
|
||||
for (int32_t i = 0; i < n_seq_a; ++i) {
|
||||
llama_seq_id seq_id_a = batch.seq_id[a][i];
|
||||
llama_seq_id seq_id_b = batch.seq_id[b][i];
|
||||
// smaller seq_ids go first
|
||||
if (seq_id_a != seq_id_b) {
|
||||
return seq_id_a < seq_id_b;
|
||||
}
|
||||
}
|
||||
}
|
||||
// when all else is equal, sort by pos
|
||||
if (batch.pos) {
|
||||
return batch.pos[a] < batch.pos[b];
|
||||
}
|
||||
// no pos, sort by id
|
||||
return a < b;
|
||||
}
|
||||
// shared prompts go first
|
||||
return n_seq_a > n_seq_b;
|
||||
}
|
||||
);
|
||||
// init seq
|
||||
llama_sbatch_seq * last_seq = nullptr;
|
||||
|
||||
for (size_t i = 0; i < n_tokens; ++i) {
|
||||
const size_t bi = ids[i];
|
||||
const int32_t n_seqs = batch.n_seq_id[bi];
|
||||
llama_seq_id * seq_ids = batch.seq_id[bi];
|
||||
if (last_seq != nullptr) {
|
||||
bool same = n_seqs == last_seq->n_seq_id;
|
||||
for (int32_t j = 0; same && j < n_seqs; ++j) {
|
||||
if (seq_ids[j] != last_seq->seq_id[j]) {
|
||||
same = false;
|
||||
}
|
||||
}
|
||||
if (same) {
|
||||
last_seq->length += 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
|
||||
seq.push_back(new_seq);
|
||||
last_seq = &seq.back();
|
||||
}
|
||||
// keep shared prompts first at the end, then sort by length descending.
|
||||
std::sort(seq.begin(), seq.end(),
|
||||
[](llama_sbatch_seq & a, llama_sbatch_seq & b) {
|
||||
if (a.n_seq_id == b.n_seq_id) {
|
||||
return a.length > b.length;
|
||||
}
|
||||
return a.n_seq_id < b.n_seq_id;
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
|
||||
batch = in_batch;
|
||||
GGML_ASSERT(batch.n_tokens > 0);
|
||||
if (!batch.pos) {
|
||||
pos.resize(batch.n_tokens);
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
pos[i] = i + p0;
|
||||
}
|
||||
batch.pos = pos.data();
|
||||
}
|
||||
if (!batch.n_seq_id) {
|
||||
n_seq_id.resize(batch.n_tokens);
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
n_seq_id[i] = seq_id_0.size();
|
||||
}
|
||||
batch.n_seq_id = n_seq_id.data();
|
||||
}
|
||||
if (!batch.seq_id) {
|
||||
seq_id.resize(batch.n_tokens + 1);
|
||||
seq_id[batch.n_tokens] = NULL;
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
seq_id[i] = seq_id_0.data();
|
||||
}
|
||||
batch.seq_id = seq_id.data();
|
||||
}
|
||||
if (!batch.logits) {
|
||||
logits.resize(batch.n_tokens);
|
||||
logits[logits.size() - 1] = true;
|
||||
batch.logits = logits.data();
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// interface implementation
|
||||
//
|
||||
|
||||
struct llama_batch llama_batch_get_one(
|
||||
llama_token * tokens,
|
||||
int32_t n_tokens) {
|
||||
return {
|
||||
/*n_tokens =*/ n_tokens,
|
||||
/*tokens =*/ tokens,
|
||||
/*embd =*/ nullptr,
|
||||
/*pos =*/ nullptr,
|
||||
/*n_seq_id =*/ nullptr,
|
||||
/*seq_id =*/ nullptr,
|
||||
/*logits =*/ nullptr,
|
||||
};
|
||||
}
|
||||
|
||||
struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
|
||||
llama_batch batch = {
|
||||
/*n_tokens =*/ 0,
|
||||
/*tokens =*/ nullptr,
|
||||
/*embd =*/ nullptr,
|
||||
/*pos =*/ nullptr,
|
||||
/*n_seq_id =*/ nullptr,
|
||||
/*seq_id =*/ nullptr,
|
||||
/*logits =*/ nullptr,
|
||||
};
|
||||
|
||||
if (embd) {
|
||||
batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
|
||||
} else {
|
||||
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
|
||||
}
|
||||
|
||||
batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc);
|
||||
batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc);
|
||||
batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
|
||||
for (int i = 0; i < n_tokens_alloc; ++i) {
|
||||
batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
|
||||
}
|
||||
batch.seq_id[n_tokens_alloc] = nullptr;
|
||||
|
||||
batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
|
||||
|
||||
return batch;
|
||||
}
|
||||
|
||||
void llama_batch_free(struct llama_batch batch) {
|
||||
if (batch.token) free(batch.token);
|
||||
if (batch.embd) free(batch.embd);
|
||||
if (batch.pos) free(batch.pos);
|
||||
if (batch.n_seq_id) free(batch.n_seq_id);
|
||||
if (batch.seq_id) {
|
||||
for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
|
||||
free(batch.seq_id[i]);
|
||||
}
|
||||
free(batch.seq_id);
|
||||
}
|
||||
if (batch.logits) free(batch.logits);
|
||||
}
|
88
examples/talk-llama/llama-batch.h
Normal file
88
examples/talk-llama/llama-batch.h
Normal file
@ -0,0 +1,88 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
#include <array>
|
||||
#include <vector>
|
||||
|
||||
// very similar to llama_batch,
|
||||
// but has more metadata about sequences
|
||||
struct llama_ubatch {
|
||||
bool equal_seqs;
|
||||
// TODO: whole_seqs for embeddings?
|
||||
|
||||
uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
|
||||
uint32_t n_seq_tokens; // tokens per sequence
|
||||
uint32_t n_seqs;
|
||||
|
||||
llama_token * token; // [n_tokens]
|
||||
float * embd; // [n_embd, n_tokens]
|
||||
llama_pos * pos; // [n_tokens]
|
||||
int32_t * n_seq_id; // [n_seqs]
|
||||
llama_seq_id ** seq_id; // [n_seqs]
|
||||
int8_t * output; // [n_tokens]
|
||||
};
|
||||
|
||||
struct llama_sbatch_seq {
|
||||
int32_t n_seq_id;
|
||||
|
||||
llama_seq_id * seq_id;
|
||||
|
||||
size_t offset;
|
||||
size_t length;
|
||||
};
|
||||
|
||||
// sequence-length-aware batch splitting
|
||||
struct llama_sbatch {
|
||||
// tokens left in this batch
|
||||
size_t n_tokens;
|
||||
|
||||
size_t n_embd;
|
||||
|
||||
bool logits_all; // TODO: remove once lctx.logits_all is removed too
|
||||
|
||||
// sorted indices into the batch
|
||||
std::vector<size_t> ids;
|
||||
// batch indices of the output
|
||||
std::vector<size_t> out_ids;
|
||||
std::vector<llama_sbatch_seq> seq;
|
||||
|
||||
const llama_batch * batch = nullptr;
|
||||
|
||||
// buffers for the ubatch
|
||||
std::vector<llama_token> ubatch_token;
|
||||
std::vector<float> ubatch_embd;
|
||||
std::vector<llama_pos> ubatch_pos;
|
||||
std::vector<int32_t> ubatch_n_seq_id;
|
||||
std::vector<llama_seq_id *> ubatch_seq_id;
|
||||
std::vector<int8_t> ubatch_output;
|
||||
|
||||
llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
|
||||
|
||||
void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);
|
||||
|
||||
// simple split, unknown number of sequences of unequal lengths
|
||||
llama_ubatch split_simple(size_t n_ubatch);
|
||||
|
||||
// make batches of equal-length sequences
|
||||
llama_ubatch split_equal(size_t n_ubatch);
|
||||
|
||||
// sequence-wise split
|
||||
llama_ubatch split_seq(size_t n_ubatch);
|
||||
|
||||
void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
|
||||
};
|
||||
|
||||
// temporary allocate memory for the input batch if needed
|
||||
struct llama_batch_allocr {
|
||||
struct llama_batch batch;
|
||||
|
||||
std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
|
||||
std::vector<llama_pos> pos;
|
||||
std::vector<int32_t> n_seq_id;
|
||||
std::vector<llama_seq_id *> seq_id;
|
||||
std::vector<int8_t> logits;
|
||||
|
||||
// optionally fulfill the batch returned by llama_batch_get_one
|
||||
llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
|
||||
};
|
587
examples/talk-llama/llama-chat.cpp
Normal file
587
examples/talk-llama/llama-chat.cpp
Normal file
@ -0,0 +1,587 @@
|
||||
#include "llama-chat.h"
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
|
||||
#if __cplusplus >= 202000L
|
||||
#define LU8(x) (const char*)(u8##x)
|
||||
#else
|
||||
#define LU8(x) u8##x
|
||||
#endif
|
||||
|
||||
// trim whitespace from the beginning and end of a string
|
||||
static std::string trim(const std::string & str) {
|
||||
size_t start = 0;
|
||||
size_t end = str.size();
|
||||
while (start < end && isspace(str[start])) {
|
||||
start += 1;
|
||||
}
|
||||
while (end > start && isspace(str[end - 1])) {
|
||||
end -= 1;
|
||||
}
|
||||
return str.substr(start, end - start);
|
||||
}
|
||||
|
||||
static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
||||
{ "chatml", LLM_CHAT_TEMPLATE_CHATML },
|
||||
{ "llama2", LLM_CHAT_TEMPLATE_LLAMA_2 },
|
||||
{ "llama2-sys", LLM_CHAT_TEMPLATE_LLAMA_2_SYS },
|
||||
{ "llama2-sys-bos", LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS },
|
||||
{ "llama2-sys-strip", LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP },
|
||||
{ "mistral-v1", LLM_CHAT_TEMPLATE_MISTRAL_V1 },
|
||||
{ "mistral-v3", LLM_CHAT_TEMPLATE_MISTRAL_V3 },
|
||||
{ "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN },
|
||||
{ "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 },
|
||||
{ "phi3", LLM_CHAT_TEMPLATE_PHI_3 },
|
||||
{ "phi4", LLM_CHAT_TEMPLATE_PHI_4 },
|
||||
{ "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 },
|
||||
{ "zephyr", LLM_CHAT_TEMPLATE_ZEPHYR },
|
||||
{ "monarch", LLM_CHAT_TEMPLATE_MONARCH },
|
||||
{ "gemma", LLM_CHAT_TEMPLATE_GEMMA },
|
||||
{ "orion", LLM_CHAT_TEMPLATE_ORION },
|
||||
{ "openchat", LLM_CHAT_TEMPLATE_OPENCHAT },
|
||||
{ "vicuna", LLM_CHAT_TEMPLATE_VICUNA },
|
||||
{ "vicuna-orca", LLM_CHAT_TEMPLATE_VICUNA_ORCA },
|
||||
{ "deepseek", LLM_CHAT_TEMPLATE_DEEPSEEK },
|
||||
{ "deepseek2", LLM_CHAT_TEMPLATE_DEEPSEEK_2 },
|
||||
{ "deepseek3", LLM_CHAT_TEMPLATE_DEEPSEEK_3 },
|
||||
{ "command-r", LLM_CHAT_TEMPLATE_COMMAND_R },
|
||||
{ "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 },
|
||||
{ "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 },
|
||||
{ "chatglm4", LLM_CHAT_TEMPLATE_CHATGML_4 },
|
||||
{ "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE },
|
||||
{ "minicpm", LLM_CHAT_TEMPLATE_MINICPM },
|
||||
{ "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
|
||||
{ "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD },
|
||||
{ "granite", LLM_CHAT_TEMPLATE_GRANITE },
|
||||
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
|
||||
{ "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
|
||||
};
|
||||
|
||||
llm_chat_template llm_chat_template_from_str(const std::string & name) {
|
||||
return LLM_CHAT_TEMPLATES.at(name);
|
||||
}
|
||||
|
||||
llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
||||
try {
|
||||
return llm_chat_template_from_str(tmpl);
|
||||
} catch (const std::out_of_range &) {
|
||||
// ignore
|
||||
}
|
||||
|
||||
auto tmpl_contains = [&tmpl](const char * haystack) -> bool {
|
||||
return tmpl.find(haystack) != std::string::npos;
|
||||
};
|
||||
if (tmpl_contains("<|im_start|>")) {
|
||||
return tmpl_contains("<|im_sep|>")
|
||||
? LLM_CHAT_TEMPLATE_PHI_4
|
||||
: LLM_CHAT_TEMPLATE_CHATML;
|
||||
} else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) {
|
||||
if (tmpl_contains("[SYSTEM_PROMPT]")) {
|
||||
return LLM_CHAT_TEMPLATE_MISTRAL_V7;
|
||||
} else if (
|
||||
// catches official 'v1' template
|
||||
tmpl_contains("' [INST] ' + system_message")
|
||||
// catches official 'v3' and 'v3-tekken' templates
|
||||
|| tmpl_contains("[AVAILABLE_TOOLS]")
|
||||
) {
|
||||
// Official mistral 'v1', 'v3' and 'v3-tekken' templates
|
||||
// See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/chat_templates.md
|
||||
// See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/templates.md
|
||||
if (tmpl_contains(" [INST]")) {
|
||||
return LLM_CHAT_TEMPLATE_MISTRAL_V1;
|
||||
} else if (tmpl_contains("\"[INST]\"")) {
|
||||
return LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN;
|
||||
}
|
||||
return LLM_CHAT_TEMPLATE_MISTRAL_V3;
|
||||
} else {
|
||||
// llama2 template and its variants
|
||||
// [variant] support system message
|
||||
// See: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
|
||||
bool support_system_message = tmpl_contains("<<SYS>>");
|
||||
bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]");
|
||||
bool strip_message = tmpl_contains("content.strip()");
|
||||
if (strip_message) {
|
||||
return LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP;
|
||||
} else if (add_bos_inside_history) {
|
||||
return LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS;
|
||||
} else if (support_system_message) {
|
||||
return LLM_CHAT_TEMPLATE_LLAMA_2_SYS;
|
||||
} else {
|
||||
return LLM_CHAT_TEMPLATE_LLAMA_2;
|
||||
}
|
||||
}
|
||||
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) {
|
||||
return LLM_CHAT_TEMPLATE_PHI_3;
|
||||
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
|
||||
return tmpl_contains("</s>") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE;
|
||||
} else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) {
|
||||
return LLM_CHAT_TEMPLATE_ZEPHYR;
|
||||
} else if (tmpl_contains("bos_token + message['role']")) {
|
||||
return LLM_CHAT_TEMPLATE_MONARCH;
|
||||
} else if (tmpl_contains("<start_of_turn>")) {
|
||||
return LLM_CHAT_TEMPLATE_GEMMA;
|
||||
} else if (tmpl_contains("'\\n\\nAssistant: ' + eos_token")) {
|
||||
// OrionStarAI/Orion-14B-Chat
|
||||
return LLM_CHAT_TEMPLATE_ORION;
|
||||
} else if (tmpl_contains("GPT4 Correct ")) {
|
||||
// openchat/openchat-3.5-0106
|
||||
return LLM_CHAT_TEMPLATE_OPENCHAT;
|
||||
} else if (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: ")) {
|
||||
// eachadea/vicuna-13b-1.1 (and Orca variant)
|
||||
if (tmpl_contains("SYSTEM: ")) {
|
||||
return LLM_CHAT_TEMPLATE_VICUNA_ORCA;
|
||||
}
|
||||
return LLM_CHAT_TEMPLATE_VICUNA;
|
||||
} else if (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>")) {
|
||||
// deepseek-ai/deepseek-coder-33b-instruct
|
||||
return LLM_CHAT_TEMPLATE_DEEPSEEK;
|
||||
} else if (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>")) {
|
||||
// CohereForAI/c4ai-command-r-plus
|
||||
return LLM_CHAT_TEMPLATE_COMMAND_R;
|
||||
} else if (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>")) {
|
||||
return LLM_CHAT_TEMPLATE_LLAMA_3;
|
||||
} else if (tmpl_contains("[gMASK]sop")) {
|
||||
// chatglm3-6b
|
||||
return LLM_CHAT_TEMPLATE_CHATGML_3;
|
||||
} else if (tmpl_contains("[gMASK]<sop>")) {
|
||||
return LLM_CHAT_TEMPLATE_CHATGML_4;
|
||||
} else if (tmpl_contains(LU8("<用户>"))) {
|
||||
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
|
||||
return LLM_CHAT_TEMPLATE_MINICPM;
|
||||
} else if (tmpl_contains("'Assistant: ' + message['content'] + eos_token")) {
|
||||
return LLM_CHAT_TEMPLATE_DEEPSEEK_2;
|
||||
} else if (tmpl_contains(LU8("<|Assistant|>")) && tmpl_contains(LU8("<|User|>")) && tmpl_contains(LU8("<|end▁of▁sentence|>"))) {
|
||||
return LLM_CHAT_TEMPLATE_DEEPSEEK_3;
|
||||
} else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) {
|
||||
// ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
|
||||
// EXAONE-3.0-7.8B-Instruct
|
||||
return LLM_CHAT_TEMPLATE_EXAONE_3;
|
||||
} else if (tmpl_contains("rwkv-world")) {
|
||||
return LLM_CHAT_TEMPLATE_RWKV_WORLD;
|
||||
} else if (tmpl_contains("<|start_of_role|>")) {
|
||||
return LLM_CHAT_TEMPLATE_GRANITE;
|
||||
} else if (tmpl_contains("message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1]")) {
|
||||
return LLM_CHAT_TEMPLATE_GIGACHAT;
|
||||
} else if (tmpl_contains("<|role_start|>")) {
|
||||
return LLM_CHAT_TEMPLATE_MEGREZ;
|
||||
}
|
||||
return LLM_CHAT_TEMPLATE_UNKNOWN;
|
||||
}
|
||||
|
||||
// Simple version of "llama_apply_chat_template" that only works with strings
|
||||
// This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
|
||||
int32_t llm_chat_apply_template(
|
||||
llm_chat_template tmpl,
|
||||
const std::vector<const llama_chat_message *> & chat,
|
||||
std::string & dest, bool add_ass) {
|
||||
// Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
|
||||
std::stringstream ss;
|
||||
if (tmpl == LLM_CHAT_TEMPLATE_CHATML) {
|
||||
// chatml template
|
||||
for (auto message : chat) {
|
||||
ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|im_start|>assistant\n";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7) {
|
||||
// Official mistral 'v7' template
|
||||
// See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
std::string content(message->content);
|
||||
if (role == "system") {
|
||||
ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]";
|
||||
} else if (role == "user") {
|
||||
ss << "[INST] " << content << "[/INST]";
|
||||
}
|
||||
else {
|
||||
ss << " " << content << "</s>";
|
||||
}
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1
|
||||
|| tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3
|
||||
|| tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN) {
|
||||
// See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/chat_templates.md
|
||||
// See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/templates.md
|
||||
std::string leading_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1 ? " " : "";
|
||||
std::string trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN ? "" : " ";
|
||||
bool trim_assistant_message = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3;
|
||||
bool is_inside_turn = false;
|
||||
for (auto message : chat) {
|
||||
if (!is_inside_turn) {
|
||||
ss << leading_space << "[INST]" << trailing_space;
|
||||
is_inside_turn = true;
|
||||
}
|
||||
std::string role(message->role);
|
||||
std::string content(message->content);
|
||||
if (role == "system") {
|
||||
ss << content << "\n\n";
|
||||
} else if (role == "user") {
|
||||
ss << content << leading_space << "[/INST]";
|
||||
} else {
|
||||
ss << trailing_space << (trim_assistant_message ? trim(content) : content) << "</s>";
|
||||
is_inside_turn = false;
|
||||
}
|
||||
}
|
||||
} else if (
|
||||
tmpl == LLM_CHAT_TEMPLATE_LLAMA_2
|
||||
|| tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS
|
||||
|| tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS
|
||||
|| tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP) {
|
||||
// llama2 template and its variants
|
||||
// [variant] support system message
|
||||
// See: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
|
||||
bool support_system_message = tmpl != LLM_CHAT_TEMPLATE_LLAMA_2;
|
||||
// [variant] add BOS inside history
|
||||
bool add_bos_inside_history = tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS;
|
||||
// [variant] trim spaces from the input message
|
||||
bool strip_message = tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP;
|
||||
// construct the prompt
|
||||
bool is_inside_turn = true; // skip BOS at the beginning
|
||||
ss << "[INST] ";
|
||||
for (auto message : chat) {
|
||||
std::string content = strip_message ? trim(message->content) : message->content;
|
||||
std::string role(message->role);
|
||||
if (!is_inside_turn) {
|
||||
is_inside_turn = true;
|
||||
ss << (add_bos_inside_history ? "<s>[INST] " : "[INST] ");
|
||||
}
|
||||
if (role == "system") {
|
||||
if (support_system_message) {
|
||||
ss << "<<SYS>>\n" << content << "\n<</SYS>>\n\n";
|
||||
} else {
|
||||
// if the model does not support system message, we still include it in the first message, but without <<SYS>>
|
||||
ss << content << "\n";
|
||||
}
|
||||
} else if (role == "user") {
|
||||
ss << content << " [/INST]";
|
||||
} else {
|
||||
ss << content << "</s>";
|
||||
is_inside_turn = false;
|
||||
}
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_PHI_3) {
|
||||
// Phi 3
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
ss << "<|" << role << "|>\n" << message->content << "<|end|>\n";
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|assistant|>\n";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_PHI_4) {
|
||||
// chatml template
|
||||
for (auto message : chat) {
|
||||
ss << "<|im_start|>" << message->role << "<|im_sep|>" << message->content << "<|im_end|>";
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|im_start|>assistant<|im_sep|>";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_FALCON_3) {
|
||||
// Falcon 3
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
ss << "<|" << role << "|>\n" << message->content << "\n";
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|assistant|>\n";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_ZEPHYR) {
|
||||
// zephyr template
|
||||
for (auto message : chat) {
|
||||
ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|assistant|>\n";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_MONARCH) {
|
||||
// mlabonne/AlphaMonarch-7B template (the <s> is included inside history)
|
||||
for (auto message : chat) {
|
||||
std::string bos = (message == chat.front()) ? "" : "<s>"; // skip BOS for first message
|
||||
ss << bos << message->role << "\n" << message->content << "</s>\n";
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<s>assistant\n";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_GEMMA) {
|
||||
// google/gemma-7b-it
|
||||
std::string system_prompt = "";
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
if (role == "system") {
|
||||
// there is no system message for gemma, but we will merge it with user prompt, so nothing is broken
|
||||
system_prompt = trim(message->content);
|
||||
continue;
|
||||
}
|
||||
// in gemma, "assistant" is "model"
|
||||
role = role == "assistant" ? "model" : message->role;
|
||||
ss << "<start_of_turn>" << role << "\n";
|
||||
if (!system_prompt.empty() && role != "model") {
|
||||
ss << system_prompt << "\n\n";
|
||||
system_prompt = "";
|
||||
}
|
||||
ss << trim(message->content) << "<end_of_turn>\n";
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<start_of_turn>model\n";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_ORION) {
|
||||
// OrionStarAI/Orion-14B-Chat
|
||||
std::string system_prompt = "";
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
if (role == "system") {
|
||||
// there is no system message support, we will merge it with user prompt
|
||||
system_prompt = message->content;
|
||||
continue;
|
||||
} else if (role == "user") {
|
||||
ss << "Human: ";
|
||||
if (!system_prompt.empty()) {
|
||||
ss << system_prompt << "\n\n";
|
||||
system_prompt = "";
|
||||
}
|
||||
ss << message->content << "\n\nAssistant: </s>";
|
||||
} else {
|
||||
ss << message->content << "</s>";
|
||||
}
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_OPENCHAT) {
|
||||
// openchat/openchat-3.5-0106,
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
if (role == "system") {
|
||||
ss << message->content << "<|end_of_turn|>";
|
||||
} else {
|
||||
role[0] = toupper(role[0]);
|
||||
ss << "GPT4 Correct " << role << ": " << message->content << "<|end_of_turn|>";
|
||||
}
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "GPT4 Correct Assistant:";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_VICUNA || tmpl == LLM_CHAT_TEMPLATE_VICUNA_ORCA) {
|
||||
// eachadea/vicuna-13b-1.1 (and Orca variant)
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
if (role == "system") {
|
||||
// Orca-Vicuna variant uses a system prefix
|
||||
if (tmpl == LLM_CHAT_TEMPLATE_VICUNA_ORCA) {
|
||||
ss << "SYSTEM: " << message->content << "\n";
|
||||
} else {
|
||||
ss << message->content << "\n\n";
|
||||
}
|
||||
} else if (role == "user") {
|
||||
ss << "USER: " << message->content << "\n";
|
||||
} else if (role == "assistant") {
|
||||
ss << "ASSISTANT: " << message->content << "</s>\n";
|
||||
}
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "ASSISTANT:";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK) {
|
||||
// deepseek-ai/deepseek-coder-33b-instruct
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
if (role == "system") {
|
||||
ss << message->content;
|
||||
} else if (role == "user") {
|
||||
ss << "### Instruction:\n" << message->content << "\n";
|
||||
} else if (role == "assistant") {
|
||||
ss << "### Response:\n" << message->content << "\n<|EOT|>\n";
|
||||
}
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "### Response:\n";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_COMMAND_R) {
|
||||
// CohereForAI/c4ai-command-r-plus
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
if (role == "system") {
|
||||
ss << "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
|
||||
} else if (role == "user") {
|
||||
ss << "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
|
||||
} else if (role == "assistant") {
|
||||
ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
|
||||
}
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA_3) {
|
||||
// Llama 3
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
ss << "<|start_header_id|>" << role << "<|end_header_id|>\n\n" << trim(message->content) << "<|eot_id|>";
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_3) {
|
||||
// chatglm3-6b
|
||||
ss << "[gMASK]" << "sop";
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
ss << "<|" << role << "|>" << "\n " << message->content;
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|assistant|>";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_4) {
|
||||
ss << "[gMASK]" << "<sop>";
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
ss << "<|" << role << "|>" << "\n" << message->content;
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|assistant|>";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) {
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
ss << "<|" << role << "|>" << "\n" << message->content;
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|assistant|>";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_MINICPM) {
|
||||
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
if (role == "user") {
|
||||
ss << LU8("<用户>");
|
||||
ss << trim(message->content);
|
||||
ss << "<AI>";
|
||||
} else {
|
||||
ss << trim(message->content);
|
||||
}
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_2) {
|
||||
// DeepSeek-V2
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
if (role == "system") {
|
||||
ss << message->content << "\n\n";
|
||||
} else if (role == "user") {
|
||||
ss << "User: " << message->content << "\n\n";
|
||||
} else if (role == "assistant") {
|
||||
ss << "Assistant: " << message->content << LU8("<|end▁of▁sentence|>");
|
||||
}
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "Assistant:";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_3) {
|
||||
// DeepSeek-V3
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
if (role == "system") {
|
||||
ss << message->content << "\n\n";
|
||||
} else if (role == "user") {
|
||||
ss << LU8("<|User|>") << message->content;
|
||||
} else if (role == "assistant") {
|
||||
ss << LU8("<|Assistant|>") << message->content << LU8("<|end▁of▁sentence|>");
|
||||
}
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << LU8("<|Assistant|>");
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_3) {
|
||||
// ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
|
||||
// EXAONE-3.0-7.8B-Instruct
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
if (role == "system") {
|
||||
ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n";
|
||||
} else if (role == "user") {
|
||||
ss << "[|user|]" << trim(message->content) << "\n";
|
||||
} else if (role == "assistant") {
|
||||
ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n";
|
||||
}
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "[|assistant|]";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
|
||||
// this template requires the model to have "\n\n" as EOT token
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
if (role == "user") {
|
||||
ss << "User: " << message->content << "\n\nAssistant:";
|
||||
} else {
|
||||
ss << message->content << "\n\n";
|
||||
}
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) {
|
||||
// IBM Granite template
|
||||
for (const auto & message : chat) {
|
||||
std::string role(message->role);
|
||||
ss << "<|start_of_role|>" << role << "<|end_of_role|>";
|
||||
if (role == "assistant_tool_call") {
|
||||
ss << "<|tool_call|>";
|
||||
}
|
||||
ss << message->content << "<|end_of_text|>\n";
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|start_of_role|>assistant<|end_of_role|>\n";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) {
|
||||
// GigaChat template
|
||||
bool has_system = !chat.empty() && std::string(chat[0]->role) == "system";
|
||||
|
||||
// Handle system message if present
|
||||
if (has_system) {
|
||||
ss << "<s>" << chat[0]->content << "<|message_sep|>";
|
||||
} else {
|
||||
ss << "<s>";
|
||||
}
|
||||
|
||||
// Process remaining messages
|
||||
for (size_t i = has_system ? 1 : 0; i < chat.size(); i++) {
|
||||
std::string role(chat[i]->role);
|
||||
if (role == "user") {
|
||||
ss << "user<|role_sep|>" << chat[i]->content << "<|message_sep|>"
|
||||
<< "available functions<|role_sep|>[]<|message_sep|>";
|
||||
} else if (role == "assistant") {
|
||||
ss << "assistant<|role_sep|>" << chat[i]->content << "<|message_sep|>";
|
||||
}
|
||||
}
|
||||
|
||||
// Add generation prompt if needed
|
||||
if (add_ass) {
|
||||
ss << "assistant<|role_sep|>";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_MEGREZ) {
|
||||
// Megrez template
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
ss << "<|role_start|>" << role << "<|role_end|>" << message->content << "<|turn_end|>";
|
||||
}
|
||||
|
||||
if (add_ass) {
|
||||
ss << "<|role_start|>assistant<|role_end|>";
|
||||
}
|
||||
} else {
|
||||
// template not supported
|
||||
return -1;
|
||||
}
|
||||
dest = ss.str();
|
||||
return dest.size();
|
||||
}
|
||||
|
||||
// public interface
|
||||
|
||||
int32_t llama_chat_builtin_templates(const char ** output, size_t len) {
|
||||
auto it = LLM_CHAT_TEMPLATES.begin();
|
||||
for (size_t i = 0; i < std::min(len, LLM_CHAT_TEMPLATES.size()); i++) {
|
||||
output[i] = it->first.c_str();
|
||||
std::advance(it, 1);
|
||||
}
|
||||
return (int32_t) LLM_CHAT_TEMPLATES.size();
|
||||
}
|
||||
|
53
examples/talk-llama/llama-chat.h
Normal file
53
examples/talk-llama/llama-chat.h
Normal file
@ -0,0 +1,53 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <cstdint>
|
||||
|
||||
enum llm_chat_template {
|
||||
LLM_CHAT_TEMPLATE_CHATML,
|
||||
LLM_CHAT_TEMPLATE_LLAMA_2,
|
||||
LLM_CHAT_TEMPLATE_LLAMA_2_SYS,
|
||||
LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS,
|
||||
LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP,
|
||||
LLM_CHAT_TEMPLATE_MISTRAL_V1,
|
||||
LLM_CHAT_TEMPLATE_MISTRAL_V3,
|
||||
LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN,
|
||||
LLM_CHAT_TEMPLATE_MISTRAL_V7,
|
||||
LLM_CHAT_TEMPLATE_PHI_3,
|
||||
LLM_CHAT_TEMPLATE_PHI_4,
|
||||
LLM_CHAT_TEMPLATE_FALCON_3,
|
||||
LLM_CHAT_TEMPLATE_ZEPHYR,
|
||||
LLM_CHAT_TEMPLATE_MONARCH,
|
||||
LLM_CHAT_TEMPLATE_GEMMA,
|
||||
LLM_CHAT_TEMPLATE_ORION,
|
||||
LLM_CHAT_TEMPLATE_OPENCHAT,
|
||||
LLM_CHAT_TEMPLATE_VICUNA,
|
||||
LLM_CHAT_TEMPLATE_VICUNA_ORCA,
|
||||
LLM_CHAT_TEMPLATE_DEEPSEEK,
|
||||
LLM_CHAT_TEMPLATE_DEEPSEEK_2,
|
||||
LLM_CHAT_TEMPLATE_DEEPSEEK_3,
|
||||
LLM_CHAT_TEMPLATE_COMMAND_R,
|
||||
LLM_CHAT_TEMPLATE_LLAMA_3,
|
||||
LLM_CHAT_TEMPLATE_CHATGML_3,
|
||||
LLM_CHAT_TEMPLATE_CHATGML_4,
|
||||
LLM_CHAT_TEMPLATE_GLMEDGE,
|
||||
LLM_CHAT_TEMPLATE_MINICPM,
|
||||
LLM_CHAT_TEMPLATE_EXAONE_3,
|
||||
LLM_CHAT_TEMPLATE_RWKV_WORLD,
|
||||
LLM_CHAT_TEMPLATE_GRANITE,
|
||||
LLM_CHAT_TEMPLATE_GIGACHAT,
|
||||
LLM_CHAT_TEMPLATE_MEGREZ,
|
||||
LLM_CHAT_TEMPLATE_UNKNOWN,
|
||||
};
|
||||
|
||||
struct llama_chat_message;
|
||||
|
||||
llm_chat_template llm_chat_template_from_str(const std::string & name);
|
||||
|
||||
llm_chat_template llm_chat_detect_template(const std::string & tmpl);
|
||||
|
||||
int32_t llm_chat_apply_template(
|
||||
llm_chat_template tmpl,
|
||||
const std::vector<const llama_chat_message *> & chat,
|
||||
std::string & dest, bool add_ass);
|
1775
examples/talk-llama/llama-context.cpp
Normal file
1775
examples/talk-llama/llama-context.cpp
Normal file
File diff suppressed because it is too large
Load Diff
128
examples/talk-llama/llama-context.h
Normal file
128
examples/talk-llama/llama-context.h
Normal file
@ -0,0 +1,128 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
#include "llama-batch.h"
|
||||
#include "llama-cparams.h"
|
||||
#include "llama-model.h"
|
||||
#include "llama-kv-cache.h"
|
||||
#include "llama-adapter.h"
|
||||
|
||||
#include "ggml-cpp.h"
|
||||
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
|
||||
struct llama_context {
|
||||
llama_context(const llama_model & model)
|
||||
: model(model)
|
||||
, t_start_us(model.t_start_us)
|
||||
, t_load_us(model.t_load_us) {}
|
||||
|
||||
const struct llama_model & model;
|
||||
|
||||
struct llama_cparams cparams;
|
||||
struct llama_sbatch sbatch; // TODO: revisit if needed
|
||||
struct llama_kv_cache kv_self;
|
||||
struct llama_adapter_cvec cvec;
|
||||
|
||||
std::unordered_map<struct llama_adapter_lora *, float> lora;
|
||||
|
||||
std::vector<ggml_backend_ptr> backends;
|
||||
std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
|
||||
|
||||
ggml_backend_t backend_cpu = nullptr;
|
||||
|
||||
ggml_threadpool_t threadpool = nullptr;
|
||||
ggml_threadpool_t threadpool_batch = nullptr;
|
||||
|
||||
bool has_evaluated_once = false;
|
||||
|
||||
mutable int64_t t_start_us;
|
||||
mutable int64_t t_load_us;
|
||||
mutable int64_t t_p_eval_us = 0;
|
||||
mutable int64_t t_eval_us = 0;
|
||||
|
||||
mutable int64_t t_compute_start_us = 0;
|
||||
mutable int64_t n_queued_tokens = 0;
|
||||
|
||||
mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
|
||||
mutable int32_t n_eval = 0; // number of eval calls
|
||||
|
||||
// host buffer for the model output (logits and embeddings)
|
||||
ggml_backend_buffer_ptr buf_output;
|
||||
|
||||
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
||||
size_t logits_size = 0; // capacity (of floats) for logits
|
||||
float * logits = nullptr;
|
||||
|
||||
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
|
||||
size_t output_size = 0; // capacity (of tokens positions) for the output buffers
|
||||
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
|
||||
|
||||
bool logits_all = false;
|
||||
|
||||
// embeddings output (2-dimensional array: [n_outputs][n_embd])
|
||||
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
||||
size_t embd_size = 0; // capacity (of floats) for embeddings
|
||||
float * embd = nullptr;
|
||||
|
||||
// sequence embeddings output (map of [n_embd] vectors)
|
||||
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
|
||||
std::map<llama_seq_id, std::vector<float>> embd_seq;
|
||||
|
||||
// whether we are computing encoder output or decoder output
|
||||
bool is_encoding = false;
|
||||
|
||||
// TODO: find a better way to accommodate mutli-dimension position encoding methods
|
||||
// number of position id each token get, 1 for each token in most cases.
|
||||
// when using m-rope, it will be 3 position ids per token to representing 3 dimension coordinate.
|
||||
int n_pos_per_token = 1;
|
||||
|
||||
// output of the encoder part of the encoder-decoder models
|
||||
std::vector<float> embd_enc;
|
||||
std::vector<std::set<llama_seq_id>> seq_ids_enc;
|
||||
|
||||
// memory buffers used to evaluate the model
|
||||
std::vector<uint8_t> buf_compute_meta;
|
||||
ggml_backend_sched_ptr sched;
|
||||
|
||||
ggml_abort_callback abort_callback = nullptr;
|
||||
void * abort_callback_data = nullptr;
|
||||
|
||||
// input tensors
|
||||
struct ggml_tensor * inp_tokens; // I32 [n_batch]
|
||||
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
|
||||
struct ggml_tensor * inp_pos; // I32 [n_batch]
|
||||
struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
|
||||
struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
|
||||
struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch]
|
||||
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
|
||||
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
|
||||
struct ggml_tensor * inp_cls; // I32 [n_batch]
|
||||
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
|
||||
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
|
||||
struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
|
||||
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
|
||||
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
|
||||
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
|
||||
};
|
||||
|
||||
// TODO: make these methods of llama_context
|
||||
void llama_set_k_shift(struct llama_context & lctx);
|
||||
|
||||
void llama_set_s_copy(struct llama_context & lctx);
|
||||
|
||||
void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch);
|
||||
|
||||
// Make sure enough space is available for outputs.
|
||||
// Returns max number of outputs for which space was reserved.
|
||||
size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs);
|
||||
|
||||
// make the outputs have the same order they had in the user-provided batch
|
||||
void llama_output_reorder(struct llama_context & ctx);
|
||||
|
||||
// For internal test use
|
||||
// TODO: remove
|
||||
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(struct llama_context * ctx);
|
1
examples/talk-llama/llama-cparams.cpp
Normal file
1
examples/talk-llama/llama-cparams.cpp
Normal file
@ -0,0 +1 @@
|
||||
#include "llama-cparams.h"
|
37
examples/talk-llama/llama-cparams.h
Normal file
37
examples/talk-llama/llama-cparams.h
Normal file
@ -0,0 +1,37 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
struct llama_cparams {
|
||||
uint32_t n_ctx; // context size used during inference
|
||||
uint32_t n_batch;
|
||||
uint32_t n_ubatch;
|
||||
uint32_t n_seq_max;
|
||||
int n_threads; // number of threads to use for generation
|
||||
int n_threads_batch; // number of threads to use for batch processing
|
||||
|
||||
float rope_freq_base;
|
||||
float rope_freq_scale;
|
||||
|
||||
uint32_t n_ctx_orig_yarn;
|
||||
// These hyperparameters are not exposed in GGUF, because all
|
||||
// existing YaRN models use the same values for them.
|
||||
float yarn_ext_factor;
|
||||
float yarn_attn_factor;
|
||||
float yarn_beta_fast;
|
||||
float yarn_beta_slow;
|
||||
float defrag_thold;
|
||||
|
||||
bool embeddings;
|
||||
bool causal_attn;
|
||||
bool offload_kqv;
|
||||
bool flash_attn;
|
||||
bool no_perf;
|
||||
|
||||
enum llama_pooling_type pooling_type;
|
||||
|
||||
ggml_backend_sched_eval_callback cb_eval;
|
||||
void * cb_eval_user_data;
|
||||
};
|
@ -1,5 +1,6 @@
|
||||
#include "llama-grammar.h"
|
||||
|
||||
#include "llama-impl.h"
|
||||
#include "llama-vocab.h"
|
||||
#include "llama-sampling.h"
|
||||
|
||||
@ -559,7 +560,7 @@ bool llama_grammar_parser::parse(const char * src) {
|
||||
}
|
||||
}
|
||||
} catch (const std::exception & err) {
|
||||
fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
|
||||
fprintf(stderr, "%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src);
|
||||
rules.clear();
|
||||
return false;
|
||||
}
|
||||
@ -822,15 +823,11 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
|
||||
return grammar->stacks;
|
||||
}
|
||||
|
||||
void llama_grammar_accept(
|
||||
const llama_grammar_rules & rules,
|
||||
const llama_grammar_stacks & stacks,
|
||||
const uint32_t chr,
|
||||
llama_grammar_stacks & stacks_new) {
|
||||
stacks_new.clear();
|
||||
stacks_new.reserve(stacks.size());
|
||||
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
|
||||
llama_grammar_stacks stacks_new;
|
||||
stacks_new.reserve(grammar->stacks.size());
|
||||
|
||||
for (const auto & stack : stacks) {
|
||||
for (const auto & stack : grammar->stacks) {
|
||||
if (stack.empty()) {
|
||||
continue;
|
||||
}
|
||||
@ -844,9 +841,11 @@ void llama_grammar_accept(
|
||||
if (!llama_grammar_is_end_of_sequence(pos)) {
|
||||
new_stack.push_back(pos);
|
||||
}
|
||||
llama_grammar_advance_stack(rules, new_stack, stacks_new);
|
||||
llama_grammar_advance_stack(grammar->rules, new_stack, stacks_new);
|
||||
}
|
||||
}
|
||||
|
||||
grammar->stacks = std::move(stacks_new);
|
||||
}
|
||||
|
||||
llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
||||
@ -961,10 +960,28 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||
// Important: vec_rules has to be moved here, not copied, because stacks contains
|
||||
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
|
||||
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
||||
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
|
||||
return new llama_grammar {
|
||||
vocab,
|
||||
std::move(vec_rules),
|
||||
std::move(stacks),
|
||||
/* .partial_utf8 = */ {},
|
||||
/* .lazy =*/ false,
|
||||
/* .awaiting_trigger = */ false,
|
||||
/* .trigger_buffer = */ "",
|
||||
/* .trigger_tokens = */ {},
|
||||
/* .trigger_words = */ {},
|
||||
};
|
||||
}
|
||||
|
||||
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
|
||||
struct llama_grammar * llama_grammar_init_impl(
|
||||
const struct llama_vocab * vocab,
|
||||
const char * grammar_str,
|
||||
const char * grammar_root,
|
||||
bool lazy,
|
||||
const char ** trigger_words,
|
||||
size_t num_trigger_words,
|
||||
const llama_token * trigger_tokens,
|
||||
size_t num_trigger_tokens) {
|
||||
llama_grammar_parser parser;
|
||||
|
||||
// if there is a grammar, parse it
|
||||
@ -1036,10 +1053,31 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
|
||||
}
|
||||
} while (true);
|
||||
|
||||
std::vector<llama_token> vec_trigger_tokens;
|
||||
std::vector<std::string> vec_trigger_words;
|
||||
for (size_t i = 0; i < num_trigger_tokens; i++) {
|
||||
GGML_ASSERT(trigger_tokens != nullptr);
|
||||
vec_trigger_tokens.push_back(trigger_tokens[i]);
|
||||
}
|
||||
for (size_t i = 0; i < num_trigger_words; i++) {
|
||||
GGML_ASSERT(trigger_words != nullptr);
|
||||
vec_trigger_words.push_back(trigger_words[i]);
|
||||
}
|
||||
|
||||
// Important: vec_rules has to be moved here, not copied, because stacks contains
|
||||
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
|
||||
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
||||
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
|
||||
return new llama_grammar {
|
||||
vocab,
|
||||
std::move(vec_rules),
|
||||
std::move(stacks),
|
||||
/* .partial_utf8 = */ {},
|
||||
/* .lazy = */ lazy,
|
||||
/* .awaiting_trigger = */ lazy,
|
||||
/* .trigger_buffer = */ "",
|
||||
std::move(vec_trigger_tokens),
|
||||
std::move(vec_trigger_words),
|
||||
};
|
||||
}
|
||||
|
||||
void llama_grammar_free_impl(struct llama_grammar * grammar) {
|
||||
@ -1051,7 +1089,17 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
|
||||
}
|
||||
|
||||
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
|
||||
llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, };
|
||||
llama_grammar * result = new llama_grammar {
|
||||
grammar.vocab,
|
||||
grammar.rules,
|
||||
grammar.stacks,
|
||||
grammar.partial_utf8,
|
||||
grammar.lazy,
|
||||
grammar.awaiting_trigger,
|
||||
grammar.trigger_buffer,
|
||||
grammar.trigger_tokens,
|
||||
grammar.trigger_words,
|
||||
};
|
||||
|
||||
// redirect elements in stacks to point to new rules
|
||||
for (size_t is = 0; is < result->stacks.size(); is++) {
|
||||
@ -1059,7 +1107,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
|
||||
for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
|
||||
for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
|
||||
if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
|
||||
result->stacks[is][ie] = &result->rules[ir0][ir1];
|
||||
result->stacks[is][ie] = &result->rules[ir0][ir1];
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1072,6 +1120,10 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
|
||||
void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
|
||||
GGML_ASSERT(grammar.vocab != nullptr);
|
||||
|
||||
if (grammar.awaiting_trigger) {
|
||||
return;
|
||||
}
|
||||
|
||||
bool allow_eog = false;
|
||||
for (const auto & stack : grammar.stacks) {
|
||||
if (stack.empty()) {
|
||||
@ -1088,9 +1140,9 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
|
||||
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
const llama_token id = cur_p->data[i].id;
|
||||
const std::string & piece = grammar.vocab->cache_token_to_piece.at(id);
|
||||
const std::string & piece = grammar.vocab->token_to_piece(id);
|
||||
|
||||
if (llama_token_is_eog_impl(*grammar.vocab, id)) {
|
||||
if (grammar.vocab->is_eog(id)) {
|
||||
if (!allow_eog) {
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
}
|
||||
@ -1111,7 +1163,35 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
|
||||
void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
|
||||
GGML_ASSERT(grammar.vocab != nullptr);
|
||||
|
||||
if (llama_token_is_eog_impl(*grammar.vocab, token)) {
|
||||
const auto & piece = grammar.vocab->token_to_piece(token);
|
||||
|
||||
if (grammar.awaiting_trigger) {
|
||||
if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
|
||||
grammar.awaiting_trigger = false;
|
||||
grammar.trigger_buffer.clear();
|
||||
llama_grammar_accept_str(grammar, piece);
|
||||
LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
|
||||
return;
|
||||
} else {
|
||||
// TODO: consider a smarter incremental substring search algorithm (store last position to search from).
|
||||
grammar.trigger_buffer += piece;
|
||||
for (const auto & word : grammar.trigger_words) {
|
||||
auto pos = grammar.trigger_buffer.find(word);
|
||||
if (pos != std::string::npos) {
|
||||
grammar.awaiting_trigger = false;
|
||||
auto constrained_str = grammar.trigger_buffer.substr(pos);
|
||||
grammar.trigger_buffer.clear();
|
||||
llama_grammar_accept_str(grammar, constrained_str);
|
||||
LLAMA_LOG_DEBUG("Grammar triggered on word `%s`", word.c_str());
|
||||
return;
|
||||
}
|
||||
}
|
||||
LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`) (buffer: `%s`)\n", token, piece.c_str(), grammar.trigger_buffer.c_str());
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (grammar.vocab->is_eog(token)) {
|
||||
for (const auto & stack : grammar.stacks) {
|
||||
if (stack.empty()) {
|
||||
return;
|
||||
@ -1120,19 +1200,20 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
const std::string & piece = grammar.vocab->cache_token_to_piece.at(token);
|
||||
llama_grammar_accept_str(grammar, piece);
|
||||
}
|
||||
|
||||
void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) {
|
||||
// Note terminating 0 in decoded string
|
||||
const auto decoded = decode_utf8(piece, grammar.partial_utf8);
|
||||
const auto & code_points = decoded.first;
|
||||
|
||||
llama_grammar_stacks stacks_new;
|
||||
|
||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new);
|
||||
grammar.stacks = std::move(stacks_new);
|
||||
llama_grammar_accept(&grammar, *it);
|
||||
}
|
||||
|
||||
grammar.partial_utf8 = decoded.second;
|
||||
GGML_ASSERT(!grammar.stacks.empty());
|
||||
if (grammar.stacks.empty()) {
|
||||
throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
|
||||
}
|
||||
}
|
||||
|
@ -1,8 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama-impl.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
struct llama_vocab;
|
||||
|
||||
@ -58,6 +60,7 @@ using llama_grammar_rules = std::vector<llama_grammar_rule>;
|
||||
using llama_grammar_stacks = std::vector<llama_grammar_stack>;
|
||||
using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
|
||||
|
||||
// TODO: remove, needed for tests atm
|
||||
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
|
||||
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
|
||||
|
||||
@ -65,11 +68,7 @@ const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar
|
||||
// be positioned at a character range (see `llama_grammar_advance_stack`), and
|
||||
// produces the N possible stacks if the given char is accepted at those
|
||||
// positions
|
||||
void llama_grammar_accept(
|
||||
const llama_grammar_rules & rules,
|
||||
const llama_grammar_stacks & stacks,
|
||||
uint32_t chr,
|
||||
llama_grammar_stacks & stacks_new);
|
||||
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr);
|
||||
|
||||
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
|
||||
const llama_grammar_rules & rules,
|
||||
@ -115,6 +114,15 @@ struct llama_grammar {
|
||||
|
||||
// buffer for partially generated UTF-8 sequence from accepted tokens
|
||||
llama_partial_utf8 partial_utf8;
|
||||
|
||||
// lazy grammars wait for trigger words or tokens before constraining the sampling.
|
||||
// we still ahve trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
|
||||
// (useful e.g. for tool_choice=required)
|
||||
bool lazy = false;
|
||||
bool awaiting_trigger = false; // Initialized to true for lazy grammars only
|
||||
std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
|
||||
std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
|
||||
std::vector<std::string> trigger_words;
|
||||
};
|
||||
|
||||
//
|
||||
@ -128,7 +136,15 @@ struct llama_grammar * llama_grammar_init_impl(
|
||||
size_t n_rules,
|
||||
size_t start_rule_index);
|
||||
|
||||
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root);
|
||||
struct llama_grammar * llama_grammar_init_impl(
|
||||
const struct llama_vocab * vocab,
|
||||
const char * grammar_str,
|
||||
const char * grammar_root,
|
||||
bool lazy,
|
||||
const char ** trigger_words,
|
||||
size_t num_trigger_words,
|
||||
const llama_token * trigger_tokens,
|
||||
size_t num_trigger_tokens);
|
||||
|
||||
void llama_grammar_free_impl(struct llama_grammar * grammar);
|
||||
|
||||
@ -142,3 +158,7 @@ void llama_grammar_apply_impl(
|
||||
void llama_grammar_accept_impl(
|
||||
struct llama_grammar & grammar,
|
||||
llama_token token);
|
||||
|
||||
void llama_grammar_accept_str(
|
||||
struct llama_grammar & grammar,
|
||||
const std::string & piece);
|
||||
|
71
examples/talk-llama/llama-hparams.cpp
Normal file
71
examples/talk-llama/llama-hparams.cpp
Normal file
@ -0,0 +1,71 @@
|
||||
#include "llama-hparams.h"
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
uint32_t llama_hparams::n_head(uint32_t il) const {
|
||||
if (il < n_layer) {
|
||||
return n_head_arr[il];
|
||||
}
|
||||
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_head_kv(uint32_t il) const {
|
||||
if (il < n_layer) {
|
||||
return n_head_kv_arr[il];
|
||||
}
|
||||
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_ff(uint32_t il) const {
|
||||
if (il < n_layer) {
|
||||
return n_ff_arr[il];
|
||||
}
|
||||
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_gqa(uint32_t il) const {
|
||||
const uint32_t n_head = this->n_head(il);
|
||||
const uint32_t n_head_kv = this->n_head_kv(il);
|
||||
|
||||
if (n_head_kv == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return n_head/n_head_kv;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
|
||||
const uint32_t n_head_kv = this->n_head_kv(il);
|
||||
|
||||
return n_embd_head_k * n_head_kv;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
|
||||
const uint32_t n_head_kv = this->n_head_kv(il);
|
||||
|
||||
return n_embd_head_v * n_head_kv;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_k_s() const {
|
||||
if (wkv_head_size != 0) {
|
||||
// for RWKV models
|
||||
return token_shift_count * n_embd;
|
||||
}
|
||||
|
||||
// TODO: maybe support other convolution strides than 1
|
||||
// NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
|
||||
return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_v_s() const {
|
||||
if (wkv_head_size != 0) {
|
||||
// corresponds to RWKV's wkv_states size
|
||||
return n_embd * wkv_head_size;
|
||||
}
|
||||
|
||||
// corresponds to Mamba's ssm_states size
|
||||
return ssm_d_state * ssm_d_inner;
|
||||
}
|
139
examples/talk-llama/llama-hparams.h
Normal file
139
examples/talk-llama/llama-hparams.h
Normal file
@ -0,0 +1,139 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
#include <array>
|
||||
|
||||
// bump if necessary
|
||||
#define LLAMA_MAX_LAYERS 512
|
||||
#define LLAMA_MAX_EXPERTS 256 // DeepSeekV3
|
||||
|
||||
enum llama_expert_gating_func_type {
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
|
||||
};
|
||||
|
||||
struct llama_hparams_posnet {
|
||||
uint32_t n_embd;
|
||||
uint32_t n_layer;
|
||||
};
|
||||
|
||||
struct llama_hparams_convnext {
|
||||
uint32_t n_embd;
|
||||
uint32_t n_layer;
|
||||
};
|
||||
|
||||
struct llama_hparams {
|
||||
bool vocab_only;
|
||||
bool rope_finetuned;
|
||||
bool use_par_res;
|
||||
bool swin_norm;
|
||||
|
||||
uint32_t n_ctx_train; // context size the model was trained on
|
||||
uint32_t n_embd;
|
||||
uint32_t n_embd_features = 0;
|
||||
uint32_t n_layer;
|
||||
uint32_t n_rot;
|
||||
uint32_t n_swa = 0; // sliding window attention (SWA)
|
||||
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
|
||||
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
|
||||
uint32_t n_expert = 0;
|
||||
uint32_t n_expert_used = 0;
|
||||
uint32_t n_rel_attn_bkts = 0;
|
||||
|
||||
// for WavTokenizer
|
||||
struct llama_hparams_posnet posnet;
|
||||
struct llama_hparams_convnext convnext;
|
||||
|
||||
std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_arr;
|
||||
std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
|
||||
std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
|
||||
|
||||
uint32_t n_layer_dense_lead = 0;
|
||||
uint32_t n_lora_q = 0;
|
||||
uint32_t n_lora_kv = 0;
|
||||
uint32_t n_ff_exp = 0;
|
||||
uint32_t n_ff_shexp = 0;
|
||||
uint32_t n_expert_shared = 0;
|
||||
uint32_t n_norm_groups = 0;
|
||||
|
||||
float expert_weights_scale = 0.0;
|
||||
bool expert_weights_norm = false;
|
||||
uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
|
||||
|
||||
float f_norm_eps;
|
||||
float f_norm_rms_eps;
|
||||
float f_norm_group_eps;
|
||||
|
||||
float f_attn_logit_softcapping = 50.0f;
|
||||
float f_final_logit_softcapping = 30.0f;
|
||||
|
||||
// for RWKV
|
||||
uint32_t rescale_every_n_layers = 0;
|
||||
uint32_t time_mix_extra_dim = 0;
|
||||
uint32_t time_decay_extra_dim = 0;
|
||||
uint32_t wkv_head_size = 0;
|
||||
uint32_t token_shift_count = 2;
|
||||
|
||||
float rope_attn_factor = 1.0f;
|
||||
float rope_freq_base_train;
|
||||
float rope_freq_scale_train;
|
||||
uint32_t n_ctx_orig_yarn;
|
||||
float rope_yarn_log_mul;
|
||||
|
||||
std::array<int, 4> rope_sections;
|
||||
|
||||
// for State Space Models
|
||||
uint32_t ssm_d_conv = 0;
|
||||
uint32_t ssm_d_inner = 0;
|
||||
uint32_t ssm_d_state = 0;
|
||||
uint32_t ssm_dt_rank = 0;
|
||||
|
||||
bool ssm_dt_b_c_rms = false;
|
||||
|
||||
float f_clamp_kqv = 0.0f;
|
||||
float f_max_alibi_bias = 0.0f;
|
||||
float f_logit_scale = 0.0f;
|
||||
|
||||
// Additional scale factors (Granite/Granite MoE)
|
||||
float f_residual_scale = 0.0f;
|
||||
float f_embedding_scale = 0.0f;
|
||||
float f_attention_scale = 0.0f;
|
||||
|
||||
bool causal_attn = true;
|
||||
bool use_alibi = false;
|
||||
bool attn_soft_cap = false;
|
||||
|
||||
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
|
||||
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
|
||||
llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
|
||||
|
||||
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
|
||||
enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
|
||||
|
||||
uint32_t n_head(uint32_t il = 0) const;
|
||||
|
||||
uint32_t n_head_kv(uint32_t il = 0) const;
|
||||
|
||||
uint32_t n_ff(uint32_t il = 0) const;
|
||||
|
||||
uint32_t n_gqa(uint32_t il = 0) const;
|
||||
|
||||
// dimension of key embeddings across all k-v heads
|
||||
uint32_t n_embd_k_gqa(uint32_t il = 0) const;
|
||||
|
||||
// dimension of value embeddings across all k-v heads
|
||||
uint32_t n_embd_v_gqa(uint32_t il = 0) const;
|
||||
|
||||
// dimension of the rolling state embeddings
|
||||
// corresponds to Mamba's conv_states size or RWKV's token_shift states size
|
||||
uint32_t n_embd_k_s() const;
|
||||
|
||||
// dimension of the recurrent state embeddings
|
||||
uint32_t n_embd_v_s() const;
|
||||
};
|
||||
|
||||
static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
|
||||
|
167
examples/talk-llama/llama-impl.cpp
Normal file
167
examples/talk-llama/llama-impl.cpp
Normal file
@ -0,0 +1,167 @@
|
||||
#include "llama-impl.h"
|
||||
|
||||
#include "gguf.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <cinttypes>
|
||||
#include <climits>
|
||||
#include <cstdarg>
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
|
||||
struct llama_logger_state {
|
||||
ggml_log_callback log_callback = llama_log_callback_default;
|
||||
void * log_callback_user_data = nullptr;
|
||||
};
|
||||
|
||||
static llama_logger_state g_logger_state;
|
||||
|
||||
time_meas::time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
|
||||
|
||||
time_meas::~time_meas() {
|
||||
if (t_start_us >= 0) {
|
||||
t_acc += ggml_time_us() - t_start_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_log_set(ggml_log_callback log_callback, void * user_data) {
|
||||
ggml_log_set(log_callback, user_data);
|
||||
g_logger_state.log_callback = log_callback ? log_callback : llama_log_callback_default;
|
||||
g_logger_state.log_callback_user_data = user_data;
|
||||
}
|
||||
|
||||
static void llama_log_internal_v(ggml_log_level level, const char * format, va_list args) {
|
||||
va_list args_copy;
|
||||
va_copy(args_copy, args);
|
||||
char buffer[128];
|
||||
int len = vsnprintf(buffer, 128, format, args);
|
||||
if (len < 128) {
|
||||
g_logger_state.log_callback(level, buffer, g_logger_state.log_callback_user_data);
|
||||
} else {
|
||||
char * buffer2 = new char[len + 1];
|
||||
vsnprintf(buffer2, len + 1, format, args_copy);
|
||||
buffer2[len] = 0;
|
||||
g_logger_state.log_callback(level, buffer2, g_logger_state.log_callback_user_data);
|
||||
delete[] buffer2;
|
||||
}
|
||||
va_end(args_copy);
|
||||
}
|
||||
|
||||
void llama_log_internal(ggml_log_level level, const char * format, ...) {
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
llama_log_internal_v(level, format, args);
|
||||
va_end(args);
|
||||
}
|
||||
|
||||
void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
|
||||
(void) level;
|
||||
(void) user_data;
|
||||
fputs(text, stderr);
|
||||
fflush(stderr);
|
||||
}
|
||||
|
||||
void replace_all(std::string & s, const std::string & search, const std::string & replace) {
|
||||
if (search.empty()) {
|
||||
return;
|
||||
}
|
||||
std::string builder;
|
||||
builder.reserve(s.length());
|
||||
size_t pos = 0;
|
||||
size_t last_pos = 0;
|
||||
while ((pos = s.find(search, last_pos)) != std::string::npos) {
|
||||
builder.append(s, last_pos, pos - last_pos);
|
||||
builder.append(replace);
|
||||
last_pos = pos + search.length();
|
||||
}
|
||||
builder.append(s, last_pos, std::string::npos);
|
||||
s = std::move(builder);
|
||||
}
|
||||
|
||||
std::string format(const char * fmt, ...) {
|
||||
va_list ap;
|
||||
va_list ap2;
|
||||
va_start(ap, fmt);
|
||||
va_copy(ap2, ap);
|
||||
int size = vsnprintf(NULL, 0, fmt, ap);
|
||||
GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
|
||||
std::vector<char> buf(size + 1);
|
||||
int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
|
||||
GGML_ASSERT(size2 == size);
|
||||
va_end(ap2);
|
||||
va_end(ap);
|
||||
return std::string(buf.data(), size);
|
||||
}
|
||||
|
||||
std::string llama_format_tensor_shape(const std::vector<int64_t> & ne) {
|
||||
char buf[256];
|
||||
snprintf(buf, sizeof(buf), "%5" PRId64, ne.at(0));
|
||||
for (size_t i = 1; i < ne.size(); i++) {
|
||||
snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, ne.at(i));
|
||||
}
|
||||
return buf;
|
||||
}
|
||||
|
||||
std::string llama_format_tensor_shape(const struct ggml_tensor * t) {
|
||||
char buf[256];
|
||||
snprintf(buf, sizeof(buf), "%5" PRId64, t->ne[0]);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, t->ne[i]);
|
||||
}
|
||||
return buf;
|
||||
}
|
||||
|
||||
static std::string gguf_data_to_str(enum gguf_type type, const void * data, int i) {
|
||||
switch (type) {
|
||||
case GGUF_TYPE_UINT8: return std::to_string(((const uint8_t *)data)[i]);
|
||||
case GGUF_TYPE_INT8: return std::to_string(((const int8_t *)data)[i]);
|
||||
case GGUF_TYPE_UINT16: return std::to_string(((const uint16_t *)data)[i]);
|
||||
case GGUF_TYPE_INT16: return std::to_string(((const int16_t *)data)[i]);
|
||||
case GGUF_TYPE_UINT32: return std::to_string(((const uint32_t *)data)[i]);
|
||||
case GGUF_TYPE_INT32: return std::to_string(((const int32_t *)data)[i]);
|
||||
case GGUF_TYPE_UINT64: return std::to_string(((const uint64_t *)data)[i]);
|
||||
case GGUF_TYPE_INT64: return std::to_string(((const int64_t *)data)[i]);
|
||||
case GGUF_TYPE_FLOAT32: return std::to_string(((const float *)data)[i]);
|
||||
case GGUF_TYPE_FLOAT64: return std::to_string(((const double *)data)[i]);
|
||||
case GGUF_TYPE_BOOL: return ((const bool *)data)[i] ? "true" : "false";
|
||||
default: return format("unknown type %d", type);
|
||||
}
|
||||
}
|
||||
|
||||
std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
|
||||
const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);
|
||||
|
||||
switch (type) {
|
||||
case GGUF_TYPE_STRING:
|
||||
return gguf_get_val_str(ctx_gguf, i);
|
||||
case GGUF_TYPE_ARRAY:
|
||||
{
|
||||
const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i);
|
||||
int arr_n = gguf_get_arr_n(ctx_gguf, i);
|
||||
const void * data = arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx_gguf, i);
|
||||
std::stringstream ss;
|
||||
ss << "[";
|
||||
for (int j = 0; j < arr_n; j++) {
|
||||
if (arr_type == GGUF_TYPE_STRING) {
|
||||
std::string val = gguf_get_arr_str(ctx_gguf, i, j);
|
||||
// escape quotes
|
||||
replace_all(val, "\\", "\\\\");
|
||||
replace_all(val, "\"", "\\\"");
|
||||
ss << '"' << val << '"';
|
||||
} else if (arr_type == GGUF_TYPE_ARRAY) {
|
||||
ss << "???";
|
||||
} else {
|
||||
ss << gguf_data_to_str(arr_type, data, j);
|
||||
}
|
||||
if (j < arr_n - 1) {
|
||||
ss << ", ";
|
||||
}
|
||||
}
|
||||
ss << "]";
|
||||
return ss.str();
|
||||
}
|
||||
default:
|
||||
return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0);
|
||||
}
|
||||
}
|
@ -1,10 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
#include "ggml.h" // for ggml_log_level
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <stdexcept>
|
||||
|
||||
#ifdef __GNUC__
|
||||
#ifdef __MINGW32__
|
||||
@ -35,147 +34,28 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void *
|
||||
// helpers
|
||||
//
|
||||
|
||||
struct time_meas {
|
||||
time_meas(int64_t & t_acc, bool disable = false) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
|
||||
template <typename T>
|
||||
struct no_init {
|
||||
T value;
|
||||
no_init() { /* do nothing */ }
|
||||
};
|
||||
|
||||
~time_meas() {
|
||||
if (t_start_us >= 0) {
|
||||
t_acc += ggml_time_us() - t_start_us;
|
||||
}
|
||||
}
|
||||
struct time_meas {
|
||||
time_meas(int64_t & t_acc, bool disable = false);
|
||||
~time_meas();
|
||||
|
||||
const int64_t t_start_us;
|
||||
|
||||
int64_t & t_acc;
|
||||
};
|
||||
|
||||
static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
|
||||
if (search.empty()) {
|
||||
return;
|
||||
}
|
||||
std::string builder;
|
||||
builder.reserve(s.length());
|
||||
size_t pos = 0;
|
||||
size_t last_pos = 0;
|
||||
while ((pos = s.find(search, last_pos)) != std::string::npos) {
|
||||
builder.append(s, last_pos, pos - last_pos);
|
||||
builder.append(replace);
|
||||
last_pos = pos + search.length();
|
||||
}
|
||||
builder.append(s, last_pos, std::string::npos);
|
||||
s = std::move(builder);
|
||||
}
|
||||
void replace_all(std::string & s, const std::string & search, const std::string & replace);
|
||||
|
||||
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
|
||||
struct llama_context * ctx
|
||||
);
|
||||
// TODO: rename to llama_format ?
|
||||
LLAMA_ATTRIBUTE_FORMAT(1, 2)
|
||||
std::string format(const char * fmt, ...);
|
||||
|
||||
// the ring buffer works similarly to std::deque, but with a fixed capacity
|
||||
template<typename T>
|
||||
struct ring_buffer {
|
||||
ring_buffer(size_t cap) : capacity(cap), data(cap) {}
|
||||
std::string llama_format_tensor_shape(const std::vector<int64_t> & ne);
|
||||
std::string llama_format_tensor_shape(const struct ggml_tensor * t);
|
||||
|
||||
T & front() {
|
||||
if (sz == 0) {
|
||||
throw std::runtime_error("ring buffer is empty");
|
||||
}
|
||||
return data[first];
|
||||
}
|
||||
|
||||
const T & front() const {
|
||||
if (sz == 0) {
|
||||
throw std::runtime_error("ring buffer is empty");
|
||||
}
|
||||
return data[first];
|
||||
}
|
||||
|
||||
T & back() {
|
||||
if (sz == 0) {
|
||||
throw std::runtime_error("ring buffer is empty");
|
||||
}
|
||||
return data[pos];
|
||||
}
|
||||
|
||||
const T & back() const {
|
||||
if (sz == 0) {
|
||||
throw std::runtime_error("ring buffer is empty");
|
||||
}
|
||||
return data[pos];
|
||||
}
|
||||
|
||||
void push_back(const T & value) {
|
||||
if (capacity == 0) {
|
||||
throw std::runtime_error("ring buffer: capacity is zero");
|
||||
}
|
||||
|
||||
if (sz == capacity) {
|
||||
// advance the start when buffer is full
|
||||
first = (first + 1) % capacity;
|
||||
} else {
|
||||
sz++;
|
||||
}
|
||||
data[pos] = value;
|
||||
pos = (pos + 1) % capacity;
|
||||
}
|
||||
|
||||
T pop_front() {
|
||||
if (sz == 0) {
|
||||
throw std::runtime_error("ring buffer is empty");
|
||||
}
|
||||
T value = data[first];
|
||||
first = (first + 1) % capacity;
|
||||
sz--;
|
||||
return value;
|
||||
}
|
||||
|
||||
//T & operator[](size_t i) {
|
||||
// if (i >= sz) {
|
||||
// throw std::runtime_error("ring buffer: index out of bounds");
|
||||
// }
|
||||
// return data[(first + i) % capacity];
|
||||
//}
|
||||
|
||||
//const T & at(size_t i) const {
|
||||
// if (i >= sz) {
|
||||
// throw std::runtime_error("ring buffer: index out of bounds");
|
||||
// }
|
||||
// return data[(first + i) % capacity];
|
||||
//}
|
||||
|
||||
const T & rat(size_t i) const {
|
||||
if (i >= sz) {
|
||||
throw std::runtime_error("ring buffer: index out of bounds");
|
||||
}
|
||||
return data[(first + sz - i - 1) % capacity];
|
||||
}
|
||||
|
||||
std::vector<T> to_vector() const {
|
||||
std::vector<T> result;
|
||||
result.reserve(sz);
|
||||
for (size_t i = 0; i < sz; i++) {
|
||||
result.push_back(data[(first + i) % capacity]);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void clear() {
|
||||
// here only reset the status of the buffer
|
||||
sz = 0;
|
||||
first = 0;
|
||||
pos = 0;
|
||||
}
|
||||
|
||||
bool empty() const {
|
||||
return sz == 0;
|
||||
}
|
||||
|
||||
size_t size() const {
|
||||
return sz;
|
||||
}
|
||||
|
||||
size_t capacity = 0;
|
||||
size_t sz = 0;
|
||||
size_t first = 0;
|
||||
size_t pos = 0;
|
||||
std::vector<T> data;
|
||||
};
|
||||
std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i);
|
||||
|
718
examples/talk-llama/llama-kv-cache.cpp
Normal file
718
examples/talk-llama/llama-kv-cache.cpp
Normal file
@ -0,0 +1,718 @@
|
||||
#include "llama-kv-cache.h"
|
||||
|
||||
#include "llama-impl.h"
|
||||
#include "llama-batch.h"
|
||||
#include "llama-cparams.h"
|
||||
#include "llama-model.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
|
||||
static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
|
||||
|
||||
uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) {
|
||||
// the FA kernels require padding to avoid extra runtime boundary checks
|
||||
return cparams.flash_attn ? 256u : 32u;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_init(
|
||||
struct llama_kv_cache & cache,
|
||||
const llama_model & model,
|
||||
const llama_cparams & cparams,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
uint32_t kv_size,
|
||||
bool offload) {
|
||||
const struct llama_hparams & hparams = model.hparams;
|
||||
|
||||
const int32_t n_layer = hparams.n_layer;
|
||||
|
||||
cache.has_shift = false;
|
||||
|
||||
cache.recurrent = llama_model_is_recurrent(&model);
|
||||
cache.v_trans = !cache.recurrent && !cparams.flash_attn;
|
||||
cache.can_shift = !cache.recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
|
||||
|
||||
LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
|
||||
__func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, cache.can_shift);
|
||||
|
||||
cache.head = 0;
|
||||
cache.size = kv_size;
|
||||
cache.used = 0;
|
||||
|
||||
cache.type_k = type_k;
|
||||
cache.type_v = type_v;
|
||||
|
||||
cache.cells.clear();
|
||||
cache.cells.resize(kv_size);
|
||||
|
||||
// create a context for each buffer type
|
||||
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
||||
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
|
||||
auto it = ctx_map.find(buft);
|
||||
if (it == ctx_map.end()) {
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
ggml_context * ctx = ggml_init(params);
|
||||
if (!ctx) {
|
||||
return nullptr;
|
||||
}
|
||||
ctx_map[buft] = ctx;
|
||||
cache.ctxs.emplace_back(ctx);
|
||||
return ctx;
|
||||
}
|
||||
return it->second;
|
||||
};
|
||||
|
||||
cache.k_l.reserve(n_layer);
|
||||
cache.v_l.reserve(n_layer);
|
||||
|
||||
for (int i = 0; i < n_layer; i++) {
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: layer %d: n_embd_k_gqa = %d, n_embd_v_gqa = %d\n", __func__, i, n_embd_k_gqa, n_embd_v_gqa);
|
||||
|
||||
ggml_backend_buffer_type_t buft;
|
||||
if (offload) {
|
||||
auto * dev = model.dev_layer(i);
|
||||
buft = ggml_backend_dev_buffer_type(dev);
|
||||
} else {
|
||||
buft = ggml_backend_cpu_buffer_type();
|
||||
}
|
||||
ggml_context * ctx = ctx_for_buft(buft);
|
||||
|
||||
if (!ctx) {
|
||||
LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
|
||||
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
|
||||
ggml_format_name(k, "cache_k_l%d", i);
|
||||
ggml_format_name(v, "cache_v_l%d", i);
|
||||
cache.k_l.push_back(k);
|
||||
cache.v_l.push_back(v);
|
||||
}
|
||||
|
||||
// allocate tensors and initialize the buffers to avoid NaNs in the padding
|
||||
for (auto it : ctx_map) {
|
||||
auto * buft = it.first;
|
||||
auto * ctx = it.second;
|
||||
|
||||
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
|
||||
if (!buf) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__);
|
||||
return false;
|
||||
}
|
||||
ggml_backend_buffer_clear(buf, 0);
|
||||
LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
|
||||
cache.bufs.emplace_back(buf);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
||||
struct llama_kv_cache & cache,
|
||||
const struct llama_ubatch & ubatch) {
|
||||
const uint32_t n_tokens = ubatch.n_tokens;
|
||||
const uint32_t n_seqs = ubatch.n_seqs;
|
||||
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
|
||||
|
||||
if (cache.recurrent) {
|
||||
// For recurrent state architectures (like Mamba or RWKV),
|
||||
// each cache cell can store the state for a whole sequence.
|
||||
// A slot should be always be contiguous.
|
||||
|
||||
// can only process batches with an equal number of new tokens in each sequence
|
||||
GGML_ASSERT(ubatch.equal_seqs);
|
||||
|
||||
int32_t min = cache.size - 1;
|
||||
int32_t max = 0;
|
||||
|
||||
// everything should fit if all seq_ids are smaller than the max
|
||||
for (uint32_t s = 0; s < n_seqs; ++s) {
|
||||
const uint32_t n_seq_id = ubatch.n_seq_id[s];
|
||||
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id[s][j];
|
||||
|
||||
if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
|
||||
// too big seq_id
|
||||
// TODO: would it be possible to resize the cache instead?
|
||||
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
|
||||
return llama_kv_cache_slot_info_failed;
|
||||
}
|
||||
if (j > 0) {
|
||||
llama_kv_cell & seq = cache.cells[seq_id];
|
||||
if (seq.tail >= 0) {
|
||||
llama_kv_cell & cell = cache.cells[seq.tail];
|
||||
// clear cells from seq_ids that become shared
|
||||
// (should not normally happen, but let's handle it anyway)
|
||||
cell.seq_id.erase(seq_id);
|
||||
seq.tail = -1;
|
||||
if (cell.seq_id.empty()) {
|
||||
cell.pos = -1;
|
||||
cell.src = -1;
|
||||
cache.used -= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef NDEBUG
|
||||
{
|
||||
std::vector<int32_t> tails_verif;
|
||||
tails_verif.assign(cache.size, -1);
|
||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||
llama_kv_cell & cell = cache.cells[i];
|
||||
for (llama_seq_id seq_id : cell.seq_id) {
|
||||
if (tails_verif[seq_id] != -1) {
|
||||
LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
|
||||
}
|
||||
tails_verif[seq_id] = i;
|
||||
}
|
||||
}
|
||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||
if (tails_verif[i] != cache.cells[i].tail) {
|
||||
LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cache.cells[i].tail, tails_verif[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// find next empty cell
|
||||
uint32_t next_empty_cell = cache.head;
|
||||
|
||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||
if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
|
||||
llama_kv_cell & cell = cache.cells[next_empty_cell];
|
||||
if (cell.is_empty()) { break; }
|
||||
next_empty_cell += 1;
|
||||
}
|
||||
|
||||
// find usable cell range
|
||||
for (uint32_t s = 0; s < n_seqs; ++s) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
||||
llama_kv_cell & seq_meta = cache.cells[seq_id];
|
||||
bool has_cell = false;
|
||||
if (seq_meta.tail >= 0) {
|
||||
llama_kv_cell & cell = cache.cells[seq_meta.tail];
|
||||
GGML_ASSERT(cell.has_seq_id(seq_id));
|
||||
// does this seq_id "own" the cell?
|
||||
if (cell.seq_id.size() == 1) { has_cell = true; }
|
||||
}
|
||||
if (!has_cell) {
|
||||
llama_kv_cell & empty_cell = cache.cells[next_empty_cell];
|
||||
GGML_ASSERT(empty_cell.is_empty());
|
||||
// copy old tail into the empty cell
|
||||
if (seq_meta.tail >= 0) {
|
||||
llama_kv_cell & orig_cell = cache.cells[seq_meta.tail];
|
||||
empty_cell.pos = orig_cell.pos;
|
||||
empty_cell.src = orig_cell.src;
|
||||
orig_cell.seq_id.erase(seq_id);
|
||||
empty_cell.seq_id.insert(seq_id); // will be overwritten
|
||||
}
|
||||
seq_meta.tail = next_empty_cell;
|
||||
// find next empty cell
|
||||
if (s + 1 < n_seqs) {
|
||||
next_empty_cell += 1;
|
||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||
if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
|
||||
llama_kv_cell & cell = cache.cells[next_empty_cell];
|
||||
if (cell.is_empty()) { break; }
|
||||
next_empty_cell += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (min > seq_meta.tail) { min = seq_meta.tail; }
|
||||
if (max < seq_meta.tail) { max = seq_meta.tail; }
|
||||
}
|
||||
|
||||
// gather and re-order
|
||||
for (uint32_t s = 0; s < n_seqs; ++s) {
|
||||
int32_t dst_id = s + min;
|
||||
int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail;
|
||||
if (dst_id != src_id) {
|
||||
llama_kv_cell & dst_cell = cache.cells[dst_id];
|
||||
llama_kv_cell & src_cell = cache.cells[src_id];
|
||||
|
||||
std::swap(dst_cell.pos, src_cell.pos);
|
||||
std::swap(dst_cell.src, src_cell.src);
|
||||
std::swap(dst_cell.seq_id, src_cell.seq_id);
|
||||
|
||||
// swap tails (assuming they NEVER overlap)
|
||||
for (const llama_seq_id seq_id : src_cell.seq_id) {
|
||||
cache.cells[seq_id].tail = src_id;
|
||||
}
|
||||
for (const llama_seq_id seq_id : dst_cell.seq_id) {
|
||||
cache.cells[seq_id].tail = dst_id;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// update the pos of the used seqs
|
||||
for (uint32_t s = 0; s < n_seqs; ++s) {
|
||||
const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
|
||||
int32_t cell_id = s + min;
|
||||
llama_kv_cell & cell = cache.cells[cell_id];
|
||||
|
||||
if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
|
||||
// What should happen when the pos backtracks or skips a value?
|
||||
// Clearing the state mid-batch would require special-casing which isn't done.
|
||||
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
|
||||
__func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
|
||||
}
|
||||
cell.pos = last_pos;
|
||||
cell.seq_id.clear();
|
||||
for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id[s][j];
|
||||
cell.seq_id.insert(seq_id);
|
||||
cache.cells[seq_id].tail = cell_id;
|
||||
}
|
||||
}
|
||||
|
||||
// allow getting the range of used cells, from head to head + n
|
||||
cache.head = min;
|
||||
cache.n = max - min + 1;
|
||||
cache.used = std::count_if(cache.cells.begin(), cache.cells.end(),
|
||||
[](const llama_kv_cell& cell){ return !cell.is_empty(); });
|
||||
|
||||
// sanity check
|
||||
return llama_kv_cache_slot_info(cache.n >= n_seqs);
|
||||
}
|
||||
// otherwise, one cell per token.
|
||||
|
||||
if (n_tokens > cache.size) {
|
||||
LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
|
||||
return llama_kv_cache_slot_info_failed;
|
||||
}
|
||||
|
||||
uint32_t n_tested = 0;
|
||||
|
||||
while (true) {
|
||||
if (cache.head + n_tokens > cache.size) {
|
||||
n_tested += cache.size - cache.head;
|
||||
cache.head = 0;
|
||||
continue;
|
||||
}
|
||||
|
||||
bool found = true;
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
if (cache.cells[cache.head + i].pos >= 0) {
|
||||
found = false;
|
||||
cache.head += i + 1;
|
||||
n_tested += i + 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (found) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (n_tested >= cache.size) {
|
||||
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
|
||||
return llama_kv_cache_slot_info_failed;
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t s = 0; s < n_seqs; s++) {
|
||||
for (uint32_t i = 0; i < n_seq_tokens; ++i) {
|
||||
uint32_t k = s*n_seq_tokens + i;
|
||||
cache.cells[cache.head + k].pos = ubatch.pos[k];
|
||||
|
||||
for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
|
||||
cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cache.used += n_tokens;
|
||||
|
||||
return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens);
|
||||
}
|
||||
|
||||
uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
|
||||
for (uint32_t i = cache.size; i > 0; --i) {
|
||||
const llama_kv_cell & cell = cache.cells[i - 1];
|
||||
|
||||
if (cell.pos >= 0 && !cell.is_empty()) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void llama_kv_cache_clear(struct llama_kv_cache & cache) {
|
||||
for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
|
||||
cache.cells[i].pos = -1;
|
||||
cache.cells[i].seq_id.clear();
|
||||
cache.cells[i].src = -1;
|
||||
cache.cells[i].tail = -1;
|
||||
}
|
||||
cache.head = 0;
|
||||
cache.used = 0;
|
||||
|
||||
for (auto & buf : cache.bufs) {
|
||||
ggml_backend_buffer_clear(buf.get(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
bool llama_kv_cache_seq_rm(
|
||||
struct llama_kv_cache & cache,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
uint32_t new_head = cache.size;
|
||||
|
||||
if (p0 < 0) p0 = 0;
|
||||
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
|
||||
|
||||
// models like Mamba or RWKV can't have a state partially erased
|
||||
if (cache.recurrent) {
|
||||
if (seq_id >= (int64_t) cache.size) {
|
||||
// could be fatal
|
||||
return false;
|
||||
}
|
||||
if (0 <= seq_id) {
|
||||
int32_t & tail_id = cache.cells[seq_id].tail;
|
||||
if (tail_id >= 0) {
|
||||
const llama_kv_cell & cell = cache.cells[tail_id];
|
||||
// partial intersection is invalid
|
||||
if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
|
||||
return false;
|
||||
}
|
||||
// invalidate tails which will be cleared
|
||||
if (p0 <= cell.pos && cell.pos < p1) {
|
||||
tail_id = -1;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// seq_id is negative, then the range should include everything or nothing
|
||||
if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||
if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
|
||||
if (seq_id < 0) {
|
||||
cache.cells[i].seq_id.clear();
|
||||
} else if (cache.cells[i].has_seq_id(seq_id)) {
|
||||
cache.cells[i].seq_id.erase(seq_id);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
if (cache.cells[i].is_empty()) {
|
||||
// keep count of the number of used cells
|
||||
if (cache.cells[i].pos >= 0) cache.used--;
|
||||
|
||||
cache.cells[i].pos = -1;
|
||||
cache.cells[i].src = -1;
|
||||
if (new_head == cache.size) new_head = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we freed up a slot, set head to it so searching can start there.
|
||||
if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void llama_kv_cache_seq_cp(
|
||||
struct llama_kv_cache & cache,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
if (p0 < 0) p0 = 0;
|
||||
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
|
||||
|
||||
if (cache.recurrent) {
|
||||
if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
|
||||
llama_kv_cell & tail_src = cache.cells[seq_id_src];
|
||||
llama_kv_cell & tail_dst = cache.cells[seq_id_dst];
|
||||
if (tail_dst.tail >= 0) {
|
||||
// clear destination seq_id if it wasn't empty
|
||||
llama_kv_cell & cell_dst = cache.cells[tail_dst.tail];
|
||||
|
||||
cell_dst.seq_id.erase(seq_id_dst);
|
||||
tail_dst.tail = -1;
|
||||
if (cell_dst.seq_id.empty()) {
|
||||
cell_dst.pos = -1;
|
||||
cell_dst.delta = -1;
|
||||
cell_dst.src = -1;
|
||||
cache.used -= 1;
|
||||
}
|
||||
}
|
||||
if (tail_src.tail >= 0) {
|
||||
llama_kv_cell & cell_src = cache.cells[tail_src.tail];
|
||||
|
||||
cell_src.seq_id.insert(seq_id_dst);
|
||||
tail_dst.tail = tail_src.tail;
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
// otherwise, this is the KV cache of a Transformer-like model
|
||||
|
||||
cache.head = 0;
|
||||
|
||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||
if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
|
||||
cache.cells[i].seq_id.insert(seq_id_dst);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) {
|
||||
uint32_t new_head = cache.size;
|
||||
|
||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||
if (cache.recurrent && (llama_seq_id) i != seq_id) {
|
||||
cache.cells[i].tail = -1;
|
||||
}
|
||||
if (!cache.cells[i].has_seq_id(seq_id)) {
|
||||
if (cache.cells[i].pos >= 0) cache.used--;
|
||||
cache.cells[i].pos = -1;
|
||||
cache.cells[i].src = -1;
|
||||
cache.cells[i].seq_id.clear();
|
||||
if (new_head == cache.size) new_head = i;
|
||||
} else {
|
||||
cache.cells[i].seq_id.clear();
|
||||
cache.cells[i].seq_id.insert(seq_id);
|
||||
}
|
||||
}
|
||||
|
||||
// If we freed up a slot, set head to it so searching can start there.
|
||||
if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
|
||||
}
|
||||
|
||||
void llama_kv_cache_seq_add(
|
||||
struct llama_kv_cache & cache,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta) {
|
||||
uint32_t new_head = cache.size;
|
||||
|
||||
if (p0 < 0) p0 = 0;
|
||||
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
|
||||
// If there is no range then return early to avoid looping over the cache.
|
||||
if (p0 == p1) return;
|
||||
|
||||
if (cache.recurrent) {
|
||||
// for Mamba-like or RWKV models, only the pos needs to be shifted
|
||||
if (0 <= seq_id && seq_id < (int64_t) cache.size) {
|
||||
const int32_t tail_id = cache.cells[seq_id].tail;
|
||||
if (tail_id >= 0) {
|
||||
llama_kv_cell & cell = cache.cells[tail_id];
|
||||
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
||||
cell.pos += delta;
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
|
||||
cache.has_shift = true;
|
||||
cache.cells[i].pos += delta;
|
||||
cache.cells[i].delta += delta;
|
||||
|
||||
if (cache.cells[i].pos < 0) {
|
||||
if (!cache.cells[i].is_empty()) {
|
||||
cache.used--;
|
||||
}
|
||||
cache.cells[i].pos = -1;
|
||||
cache.cells[i].seq_id.clear();
|
||||
if (new_head == cache.size) {
|
||||
new_head = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we freed up a slot, set head to it so searching can start there.
|
||||
// Otherwise we just start the next search from the beginning.
|
||||
cache.head = new_head != cache.size ? new_head : 0;
|
||||
}
|
||||
|
||||
void llama_kv_cache_seq_div(
|
||||
struct llama_kv_cache & cache,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d) {
|
||||
if (p0 < 0) p0 = 0;
|
||||
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
|
||||
// If there is no range then return early to avoid looping over the cache.
|
||||
if (p0 == p1) return;
|
||||
|
||||
if (cache.recurrent) {
|
||||
// for Mamba-like or RWKV models, only the pos needs to be changed
|
||||
if (0 <= seq_id && seq_id < (int64_t) cache.size) {
|
||||
const int32_t tail_id = cache.cells[seq_id].tail;
|
||||
if (tail_id >= 0) {
|
||||
llama_kv_cell & cell = cache.cells[tail_id];
|
||||
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
||||
cell.pos /= d;
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
|
||||
cache.has_shift = true;
|
||||
|
||||
{
|
||||
llama_pos p_old = cache.cells[i].pos;
|
||||
cache.cells[i].pos /= d;
|
||||
cache.cells[i].delta += cache.cells[i].pos - p_old;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) {
|
||||
llama_pos result = 0;
|
||||
|
||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||
if (cache.cells[i].has_seq_id(seq_id)) {
|
||||
result = std::max(result, cache.cells[i].pos);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
|
||||
if (!cache.recurrent) {
|
||||
cache.do_defrag = true;
|
||||
}
|
||||
}
|
||||
|
||||
int32_t llama_get_kv_cache_token_count(const struct llama_kv_cache & kv) {
|
||||
int result = 0;
|
||||
|
||||
for (uint32_t i = 0; i < kv.size; i++) {
|
||||
result += kv.cells[i].seq_id.size();
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
int32_t llama_get_kv_cache_used_cells(const struct llama_kv_cache & kv) {
|
||||
return kv.used;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_can_shift(const struct llama_kv_cache & kv) {
|
||||
return kv.can_shift;
|
||||
}
|
||||
|
||||
//
|
||||
// kv cache view
|
||||
//
|
||||
|
||||
struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max) {
|
||||
struct llama_kv_cache_view result = {
|
||||
/*.n_cells = */ 0,
|
||||
/*.n_seq_max = */ n_seq_max,
|
||||
/*.token_count = */ 0,
|
||||
/*.used_cells = */ llama_get_kv_cache_used_cells(kv),
|
||||
/*.max_contiguous = */ 0,
|
||||
/*.max_contiguous_idx = */ -1,
|
||||
/*.cells = */ nullptr,
|
||||
/*.cells_sequences = */ nullptr,
|
||||
};
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
|
||||
if (view->cells != nullptr) {
|
||||
free(view->cells);
|
||||
view->cells = nullptr;
|
||||
}
|
||||
if (view->cells_sequences != nullptr) {
|
||||
free(view->cells_sequences);
|
||||
view->cells_sequences = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv) {
|
||||
if (uint32_t(view->n_cells) < kv.size || view->cells == nullptr) {
|
||||
view->n_cells = int32_t(kv.size);
|
||||
void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
|
||||
GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
|
||||
view->cells = (struct llama_kv_cache_view_cell *)p;
|
||||
p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells);
|
||||
GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
|
||||
view->cells_sequences = (llama_seq_id *)p;
|
||||
}
|
||||
|
||||
const std::vector<llama_kv_cell> & kv_cells = kv.cells;
|
||||
llama_kv_cache_view_cell * c_curr = view->cells;
|
||||
llama_seq_id * cs_curr = view->cells_sequences;
|
||||
int32_t used_cells = 0;
|
||||
int32_t token_count = 0;
|
||||
int32_t curr_contig_idx = -1;
|
||||
uint32_t max_contig = 0;
|
||||
int32_t max_contig_idx = -1;
|
||||
|
||||
for (int32_t i = 0; i < int32_t(kv.size); i++, c_curr++, cs_curr += view->n_seq_max) {
|
||||
const size_t curr_size = kv_cells[i].seq_id.size();
|
||||
token_count += curr_size;
|
||||
c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
|
||||
|
||||
if (curr_size > 0) {
|
||||
if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
|
||||
max_contig = i - curr_contig_idx;
|
||||
max_contig_idx = curr_contig_idx;
|
||||
}
|
||||
curr_contig_idx = -1;
|
||||
} else if (curr_contig_idx < 0) {
|
||||
curr_contig_idx = i;
|
||||
}
|
||||
|
||||
int seq_idx = 0;
|
||||
for (const llama_seq_id it : kv_cells[i].seq_id) {
|
||||
if (seq_idx >= view->n_seq_max) {
|
||||
break;
|
||||
}
|
||||
cs_curr[seq_idx] = it;
|
||||
seq_idx++;
|
||||
}
|
||||
if (seq_idx != 0) {
|
||||
used_cells++;
|
||||
}
|
||||
for (; seq_idx < view->n_seq_max; seq_idx++) {
|
||||
cs_curr[seq_idx] = -1;
|
||||
}
|
||||
}
|
||||
if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
|
||||
max_contig_idx = curr_contig_idx;
|
||||
max_contig = kv_cells.size() - curr_contig_idx;
|
||||
}
|
||||
view->max_contiguous = max_contig;
|
||||
view->max_contiguous_idx = max_contig_idx;
|
||||
view->token_count = token_count;
|
||||
view->used_cells = used_cells;
|
||||
if (uint32_t(used_cells) != kv.used) {
|
||||
LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
|
||||
__func__, kv.used, used_cells);
|
||||
}
|
||||
}
|
218
examples/talk-llama/llama-kv-cache.h
Normal file
218
examples/talk-llama/llama-kv-cache.h
Normal file
@ -0,0 +1,218 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
#include "ggml-cpp.h"
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
struct llama_kv_cell {
|
||||
llama_pos pos = -1;
|
||||
llama_pos delta = 0;
|
||||
int32_t src = -1; // used by recurrent state models to copy states
|
||||
int32_t tail = -1;
|
||||
|
||||
std::set<llama_seq_id> seq_id;
|
||||
|
||||
bool has_seq_id(const llama_seq_id & id) const {
|
||||
return seq_id.find(id) != seq_id.end();
|
||||
}
|
||||
|
||||
bool is_empty() const {
|
||||
return seq_id.empty();
|
||||
}
|
||||
|
||||
bool is_same_seq(const llama_kv_cell & other) const {
|
||||
return seq_id == other.seq_id;
|
||||
}
|
||||
};
|
||||
|
||||
// ring-buffer of cached KV data
|
||||
struct llama_kv_cache {
|
||||
bool has_shift = false;
|
||||
bool do_defrag = false;
|
||||
bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
|
||||
bool v_trans = true; // the value tensor is transposed
|
||||
bool can_shift = false;
|
||||
|
||||
// Note: The value of head isn't only used to optimize searching
|
||||
// for a free KV slot. llama_decode_internal also uses it, so it
|
||||
// cannot be freely changed after a slot has been allocated.
|
||||
uint32_t head = 0;
|
||||
uint32_t size = 0;
|
||||
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
||||
|
||||
// computed before each graph build
|
||||
uint32_t n = 0;
|
||||
|
||||
ggml_type type_k = GGML_TYPE_F16;
|
||||
ggml_type type_v = GGML_TYPE_F16;
|
||||
|
||||
std::vector<llama_kv_cell> cells;
|
||||
|
||||
std::vector<struct ggml_tensor *> k_l; // per layer
|
||||
std::vector<struct ggml_tensor *> v_l;
|
||||
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
|
||||
size_t total_size() const {
|
||||
size_t size = 0;
|
||||
for (const auto & buf : bufs) {
|
||||
size += ggml_backend_buffer_get_size(buf.get());
|
||||
}
|
||||
|
||||
return size;
|
||||
}
|
||||
|
||||
// TODO: better data structures to reduce the cost of this operation
|
||||
llama_pos max_pos() const {
|
||||
llama_pos max_pos = -1;
|
||||
for (const auto & cell : cells) {
|
||||
max_pos = std::max(max_pos, cell.pos);
|
||||
}
|
||||
|
||||
return max_pos;
|
||||
}
|
||||
};
|
||||
|
||||
// a structure holds information about the slot found in llama_kv_cache_find_slot
|
||||
struct llama_kv_cache_slot_info {
|
||||
std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
|
||||
bool found = false; // the slot was found
|
||||
|
||||
explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
|
||||
llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}
|
||||
|
||||
operator bool() const { return found; }
|
||||
};
|
||||
|
||||
// TODO: maybe not needed
|
||||
uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams);
|
||||
|
||||
bool llama_kv_cache_init(
|
||||
struct llama_kv_cache & cache,
|
||||
const llama_model & model,
|
||||
const llama_cparams & cparams,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
uint32_t kv_size,
|
||||
bool offload);
|
||||
|
||||
// find an empty slot of size "n_tokens" in the cache
|
||||
// updates the cache head
|
||||
// returns a structure holding information about the slot found
|
||||
// Note: On success, it's important that cache.head points
|
||||
// to the first cell of the slot.
|
||||
struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
||||
struct llama_kv_cache & cache,
|
||||
const struct llama_ubatch & batch);
|
||||
|
||||
// find how many cells are currently in use
|
||||
uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache);
|
||||
|
||||
void llama_kv_cache_clear(struct llama_kv_cache & cache);
|
||||
|
||||
bool llama_kv_cache_seq_rm(
|
||||
struct llama_kv_cache & cache,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1);
|
||||
|
||||
void llama_kv_cache_seq_cp(
|
||||
struct llama_kv_cache & cache,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1);
|
||||
|
||||
void llama_kv_cache_seq_keep(
|
||||
struct llama_kv_cache & cache,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
void llama_kv_cache_seq_add(
|
||||
struct llama_kv_cache & cache,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta);
|
||||
|
||||
void llama_kv_cache_seq_div(
|
||||
struct llama_kv_cache & cache,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d);
|
||||
|
||||
llama_pos llama_kv_cache_seq_pos_max(
|
||||
struct llama_kv_cache & cache,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
void llama_kv_cache_defrag(struct llama_kv_cache & cache);
|
||||
|
||||
int32_t llama_get_kv_cache_token_count(const struct llama_kv_cache & kv);
|
||||
|
||||
int32_t llama_get_kv_cache_used_cells(const struct llama_kv_cache & kv);
|
||||
|
||||
bool llama_kv_cache_can_shift(const struct llama_kv_cache & kv);
|
||||
|
||||
//
|
||||
// kv cache view
|
||||
//
|
||||
|
||||
struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max);
|
||||
|
||||
void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv);
|
||||
|
||||
//
|
||||
// kv cache restore
|
||||
//
|
||||
|
||||
// saves the kv_cache state for future recovery.
|
||||
// used to rollback llama_kv_cache_find_slot changes.
|
||||
struct llama_kv_slot_restorer {
|
||||
struct llama_kv_cache_state {
|
||||
uint32_t head = 0;
|
||||
uint32_t n = 0;
|
||||
} old_state;
|
||||
|
||||
// for non-recurrent models only
|
||||
// list of slots to restore
|
||||
std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries;
|
||||
|
||||
bool do_restore = false;
|
||||
|
||||
explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
|
||||
old_state.head = cache.head;
|
||||
old_state.n = cache.n;
|
||||
}
|
||||
|
||||
// saves a slot information for future restoration
|
||||
void save(const struct llama_kv_cache_slot_info & slot) {
|
||||
if (slot) {
|
||||
do_restore = true;
|
||||
if (slot.boundaries.first != slot.boundaries.second) {
|
||||
slot_boundaries.push_back(slot.boundaries);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// must be explicitly called to restore the kv_cache state
|
||||
// and rollback changes from all llama_kv_cache_find_slot calls
|
||||
void restore(struct llama_kv_cache & cache) {
|
||||
if (do_restore) {
|
||||
cache.head = old_state.head;
|
||||
cache.n = old_state.n;
|
||||
|
||||
if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
|
||||
llama_kv_cache_seq_rm(cache, -1, -1, -1);
|
||||
} else {
|
||||
for (auto & slot : slot_boundaries) {
|
||||
llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
590
examples/talk-llama/llama-mmap.cpp
Normal file
590
examples/talk-llama/llama-mmap.cpp
Normal file
@ -0,0 +1,590 @@
|
||||
#include "llama-mmap.h"
|
||||
|
||||
#include "llama-impl.h"
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
#include <cstring>
|
||||
#include <climits>
|
||||
#include <stdexcept>
|
||||
#include <cerrno>
|
||||
|
||||
#ifdef __has_include
|
||||
#if __has_include(<unistd.h>)
|
||||
#include <unistd.h>
|
||||
#if defined(_POSIX_MAPPED_FILES)
|
||||
#include <sys/mman.h>
|
||||
#include <fcntl.h>
|
||||
#endif
|
||||
#if defined(_POSIX_MEMLOCK_RANGE)
|
||||
#include <sys/resource.h>
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(_WIN32)
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
#ifndef NOMINMAX
|
||||
#define NOMINMAX
|
||||
#endif
|
||||
#include <windows.h>
|
||||
#ifndef PATH_MAX
|
||||
#define PATH_MAX MAX_PATH
|
||||
#endif
|
||||
#include <io.h>
|
||||
#endif
|
||||
|
||||
// TODO: consider moving to llama-impl.h if needed in more places
|
||||
#if defined(_WIN32)
|
||||
static std::string llama_format_win_err(DWORD err) {
|
||||
LPSTR buf;
|
||||
size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
|
||||
NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL);
|
||||
if (!size) {
|
||||
return "FormatMessageA failed";
|
||||
}
|
||||
std::string ret(buf, size);
|
||||
LocalFree(buf);
|
||||
return ret;
|
||||
}
|
||||
#endif
|
||||
|
||||
// llama_file
|
||||
|
||||
struct llama_file::impl {
|
||||
#if defined(_WIN32)
|
||||
HANDLE fp_win32;
|
||||
std::string GetErrorMessageWin32(DWORD error_code) const {
|
||||
std::string ret;
|
||||
LPSTR lpMsgBuf = NULL;
|
||||
DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
|
||||
NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL);
|
||||
if (!bufLen) {
|
||||
ret = format("Win32 error code: %lx", error_code);
|
||||
} else {
|
||||
ret = lpMsgBuf;
|
||||
LocalFree(lpMsgBuf);
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
impl(const char * fname, const char * mode) {
|
||||
fp = ggml_fopen(fname, mode);
|
||||
if (fp == NULL) {
|
||||
throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
|
||||
}
|
||||
fp_win32 = (HANDLE) _get_osfhandle(_fileno(fp));
|
||||
seek(0, SEEK_END);
|
||||
size = tell();
|
||||
seek(0, SEEK_SET);
|
||||
}
|
||||
|
||||
size_t tell() const {
|
||||
LARGE_INTEGER li;
|
||||
li.QuadPart = 0;
|
||||
BOOL ret = SetFilePointerEx(fp_win32, li, &li, FILE_CURRENT);
|
||||
if (!ret) {
|
||||
throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
|
||||
}
|
||||
|
||||
return li.QuadPart;
|
||||
}
|
||||
|
||||
void seek(size_t offset, int whence) const {
|
||||
static_assert(SEEK_SET == FILE_BEGIN, "SEEK_SET != FILE_BEGIN");
|
||||
static_assert(SEEK_CUR == FILE_CURRENT, "SEEK_CUR != FILE_CURRENT");
|
||||
static_assert(SEEK_END == FILE_END, "SEEK_END != FILE_END");
|
||||
|
||||
LARGE_INTEGER li;
|
||||
li.QuadPart = offset;
|
||||
BOOL ret = SetFilePointerEx(fp_win32, li, NULL, whence);
|
||||
if (!ret) {
|
||||
throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
|
||||
}
|
||||
}
|
||||
|
||||
void read_raw(void * ptr, size_t len) const {
|
||||
size_t bytes_read = 0;
|
||||
while (bytes_read < len) {
|
||||
size_t chunk_size = std::min<size_t>(len - bytes_read, 64*1024*1024);
|
||||
DWORD chunk_read = 0;
|
||||
BOOL result = ReadFile(fp_win32, reinterpret_cast<char*>(ptr) + bytes_read, chunk_size, &chunk_read, NULL);
|
||||
if (!result) {
|
||||
throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
|
||||
}
|
||||
if (chunk_read < chunk_size || chunk_read == 0) {
|
||||
throw std::runtime_error("unexpectedly reached end of file");
|
||||
}
|
||||
|
||||
bytes_read += chunk_read;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t read_u32() const {
|
||||
uint32_t val;
|
||||
read_raw(&val, sizeof(val));
|
||||
return val;
|
||||
}
|
||||
|
||||
void write_raw(const void * ptr, size_t len) const {
|
||||
size_t bytes_written = 0;
|
||||
while (bytes_written < len) {
|
||||
size_t chunk_size = std::min<size_t>(len - bytes_written, 64*1024*1024);
|
||||
DWORD chunk_written = 0;
|
||||
BOOL result = WriteFile(fp_win32, reinterpret_cast<char const*>(ptr) + bytes_written, chunk_size, &chunk_written, NULL);
|
||||
if (!result) {
|
||||
throw std::runtime_error(format("write error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
|
||||
}
|
||||
if (chunk_written < chunk_size || chunk_written == 0) {
|
||||
throw std::runtime_error("unexpectedly failed to write bytes");
|
||||
}
|
||||
|
||||
bytes_written += chunk_written;
|
||||
}
|
||||
}
|
||||
|
||||
void write_u32(uint32_t val) const {
|
||||
write_raw(&val, sizeof(val));
|
||||
}
|
||||
|
||||
~impl() {
|
||||
if (fp) {
|
||||
std::fclose(fp);
|
||||
}
|
||||
}
|
||||
#else
|
||||
impl(const char * fname, const char * mode) {
|
||||
fp = ggml_fopen(fname, mode);
|
||||
if (fp == NULL) {
|
||||
throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
|
||||
}
|
||||
seek(0, SEEK_END);
|
||||
size = tell();
|
||||
seek(0, SEEK_SET);
|
||||
}
|
||||
|
||||
size_t tell() const {
|
||||
// TODO: this ifdef is never true?
|
||||
#ifdef _WIN32
|
||||
__int64 ret = _ftelli64(fp);
|
||||
#else
|
||||
long ret = std::ftell(fp);
|
||||
#endif
|
||||
if (ret == -1) {
|
||||
throw std::runtime_error(format("ftell error: %s", strerror(errno)));
|
||||
}
|
||||
|
||||
return (size_t) ret;
|
||||
}
|
||||
|
||||
void seek(size_t offset, int whence) const {
|
||||
// TODO: this ifdef is never true?
|
||||
#ifdef _WIN32
|
||||
int ret = _fseeki64(fp, (__int64) offset, whence);
|
||||
#else
|
||||
int ret = std::fseek(fp, (long) offset, whence);
|
||||
#endif
|
||||
if (ret != 0) {
|
||||
throw std::runtime_error(format("seek error: %s", strerror(errno)));
|
||||
}
|
||||
}
|
||||
|
||||
void read_raw(void * ptr, size_t len) const {
|
||||
if (len == 0) {
|
||||
return;
|
||||
}
|
||||
errno = 0;
|
||||
std::size_t ret = std::fread(ptr, len, 1, fp);
|
||||
if (ferror(fp)) {
|
||||
throw std::runtime_error(format("read error: %s", strerror(errno)));
|
||||
}
|
||||
if (ret != 1) {
|
||||
throw std::runtime_error("unexpectedly reached end of file");
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t read_u32() const {
|
||||
uint32_t ret;
|
||||
read_raw(&ret, sizeof(ret));
|
||||
return ret;
|
||||
}
|
||||
|
||||
void write_raw(const void * ptr, size_t len) const {
|
||||
if (len == 0) {
|
||||
return;
|
||||
}
|
||||
errno = 0;
|
||||
size_t ret = std::fwrite(ptr, len, 1, fp);
|
||||
if (ret != 1) {
|
||||
throw std::runtime_error(format("write error: %s", strerror(errno)));
|
||||
}
|
||||
}
|
||||
|
||||
void write_u32(uint32_t val) const {
|
||||
write_raw(&val, sizeof(val));
|
||||
}
|
||||
|
||||
~impl() {
|
||||
if (fp) {
|
||||
std::fclose(fp);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
FILE * fp;
|
||||
size_t size;
|
||||
};
|
||||
|
||||
llama_file::llama_file(const char * fname, const char * mode) : pimpl(std::make_unique<impl>(fname, mode)) {}
|
||||
llama_file::~llama_file() = default;
|
||||
|
||||
size_t llama_file::tell() const { return pimpl->tell(); }
|
||||
size_t llama_file::size() const { return pimpl->size; }
|
||||
|
||||
int llama_file::file_id() const {
|
||||
#ifdef _WIN32
|
||||
return _fileno(pimpl->fp);
|
||||
#else
|
||||
#if defined(fileno)
|
||||
return fileno(pimpl->fp);
|
||||
#else
|
||||
return ::fileno(pimpl->fp);
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); }
|
||||
void llama_file::read_raw(void * ptr, size_t len) const { pimpl->read_raw(ptr, len); }
|
||||
|
||||
uint32_t llama_file::read_u32() const { return pimpl->read_u32(); }
|
||||
|
||||
void llama_file::write_raw(const void * ptr, size_t len) const { pimpl->write_raw(ptr, len); }
|
||||
void llama_file::write_u32(uint32_t val) const { pimpl->write_u32(val); }
|
||||
|
||||
// llama_mmap
|
||||
|
||||
struct llama_mmap::impl {
|
||||
#ifdef _POSIX_MAPPED_FILES
|
||||
std::vector<std::pair<size_t, size_t>> mapped_fragments;
|
||||
|
||||
impl(struct llama_file * file, size_t prefetch, bool numa) {
|
||||
size = file->size();
|
||||
int fd = file->file_id();
|
||||
int flags = MAP_SHARED;
|
||||
if (numa) { prefetch = 0; }
|
||||
#ifdef __linux__
|
||||
if (posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL)) {
|
||||
LLAMA_LOG_WARN("warning: posix_fadvise(.., POSIX_FADV_SEQUENTIAL) failed: %s\n",
|
||||
strerror(errno));
|
||||
}
|
||||
if (prefetch) { flags |= MAP_POPULATE; }
|
||||
#endif
|
||||
addr = mmap(NULL, file->size(), PROT_READ, flags, fd, 0);
|
||||
if (addr == MAP_FAILED) {
|
||||
throw std::runtime_error(format("mmap failed: %s", strerror(errno)));
|
||||
}
|
||||
|
||||
if (prefetch > 0) {
|
||||
if (posix_madvise(addr, std::min(file->size(), prefetch), POSIX_MADV_WILLNEED)) {
|
||||
LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n",
|
||||
strerror(errno));
|
||||
}
|
||||
}
|
||||
if (numa) {
|
||||
if (posix_madvise(addr, file->size(), POSIX_MADV_RANDOM)) {
|
||||
LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_RANDOM) failed: %s\n",
|
||||
strerror(errno));
|
||||
}
|
||||
}
|
||||
|
||||
mapped_fragments.emplace_back(0, file->size());
|
||||
}
|
||||
|
||||
static void align_range(size_t * first, size_t * last, size_t page_size) {
|
||||
size_t offset_in_page = *first & (page_size - 1);
|
||||
size_t offset_to_page = offset_in_page == 0 ? 0 : page_size - offset_in_page;
|
||||
*first += offset_to_page;
|
||||
|
||||
*last = *last & ~(page_size - 1);
|
||||
|
||||
if (*last <= *first) {
|
||||
*last = *first;
|
||||
}
|
||||
}
|
||||
|
||||
void unmap_fragment(size_t first, size_t last) {
|
||||
int page_size = sysconf(_SC_PAGESIZE);
|
||||
align_range(&first, &last, page_size);
|
||||
size_t len = last - first;
|
||||
|
||||
if (len == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(first % page_size == 0);
|
||||
GGML_ASSERT(last % page_size == 0);
|
||||
GGML_ASSERT(last > first);
|
||||
|
||||
void * next_page_start = (uint8_t *) addr + first;
|
||||
|
||||
if (munmap(next_page_start, len)) {
|
||||
LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno));
|
||||
}
|
||||
|
||||
std::vector<std::pair<size_t, size_t>> new_mapped_fragments;
|
||||
for (const auto & frag : mapped_fragments) {
|
||||
if (frag.first < first && frag.second > last) {
|
||||
new_mapped_fragments.emplace_back(frag.first, first);
|
||||
new_mapped_fragments.emplace_back(last, frag.second);
|
||||
} else if (frag.first < first && frag.second > first) {
|
||||
new_mapped_fragments.emplace_back(frag.first, first);
|
||||
} else if (frag.first < last && frag.second > last) {
|
||||
new_mapped_fragments.emplace_back(last, frag.second);
|
||||
} else if (frag.first >= first && frag.second <= last) {
|
||||
} else {
|
||||
new_mapped_fragments.push_back(frag);
|
||||
}
|
||||
}
|
||||
mapped_fragments = std::move(new_mapped_fragments);
|
||||
}
|
||||
|
||||
~impl() {
|
||||
for (const auto & frag : mapped_fragments) {
|
||||
if (munmap((char *) addr + frag.first, frag.second - frag.first)) {
|
||||
LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno));
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(_WIN32)
|
||||
impl(struct llama_file * file, size_t prefetch, bool numa) {
|
||||
GGML_UNUSED(numa);
|
||||
|
||||
size = file->size();
|
||||
|
||||
HANDLE hFile = (HANDLE) _get_osfhandle(file->file_id());
|
||||
|
||||
HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
|
||||
|
||||
if (hMapping == NULL) {
|
||||
DWORD error = GetLastError();
|
||||
throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str()));
|
||||
}
|
||||
|
||||
addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
|
||||
DWORD error = GetLastError();
|
||||
CloseHandle(hMapping);
|
||||
|
||||
if (addr == NULL) {
|
||||
throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str()));
|
||||
}
|
||||
|
||||
if (prefetch > 0) {
|
||||
#if _WIN32_WINNT >= 0x602
|
||||
BOOL (WINAPI *pPrefetchVirtualMemory) (HANDLE, ULONG_PTR, PWIN32_MEMORY_RANGE_ENTRY, ULONG);
|
||||
HMODULE hKernel32 = GetModuleHandleW(L"kernel32.dll");
|
||||
|
||||
pPrefetchVirtualMemory = (decltype(pPrefetchVirtualMemory))(void *) GetProcAddress(hKernel32, "PrefetchVirtualMemory");
|
||||
|
||||
if (pPrefetchVirtualMemory) {
|
||||
WIN32_MEMORY_RANGE_ENTRY range;
|
||||
range.VirtualAddress = addr;
|
||||
range.NumberOfBytes = (SIZE_T) std::min(size, prefetch);
|
||||
if (!pPrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) {
|
||||
LLAMA_LOG_WARN("warning: PrefetchVirtualMemory failed: %s\n",
|
||||
llama_format_win_err(GetLastError()).c_str());
|
||||
}
|
||||
}
|
||||
#else
|
||||
throw std::runtime_error("PrefetchVirtualMemory unavailable");
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
void unmap_fragment(size_t first, size_t last) {
|
||||
GGML_UNUSED(first);
|
||||
GGML_UNUSED(last);
|
||||
}
|
||||
|
||||
~impl() {
|
||||
if (!UnmapViewOfFile(addr)) {
|
||||
LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n",
|
||||
llama_format_win_err(GetLastError()).c_str());
|
||||
}
|
||||
}
|
||||
#else
|
||||
impl(struct llama_file * file, size_t prefetch, bool numa) {
|
||||
GGML_UNUSED(file);
|
||||
GGML_UNUSED(prefetch);
|
||||
GGML_UNUSED(numa);
|
||||
|
||||
throw std::runtime_error("mmap not supported");
|
||||
}
|
||||
|
||||
void unmap_fragment(size_t first, size_t last) {
|
||||
GGML_UNUSED(first);
|
||||
GGML_UNUSED(last);
|
||||
|
||||
throw std::runtime_error("mmap not supported");
|
||||
}
|
||||
#endif
|
||||
|
||||
void * addr;
|
||||
size_t size;
|
||||
};
|
||||
|
||||
llama_mmap::llama_mmap(struct llama_file * file, size_t prefetch, bool numa) : pimpl(std::make_unique<impl>(file, prefetch, numa)) {}
|
||||
llama_mmap::~llama_mmap() = default;
|
||||
|
||||
size_t llama_mmap::size() const { return pimpl->size; }
|
||||
void * llama_mmap::addr() const { return pimpl->addr; }
|
||||
|
||||
void llama_mmap::unmap_fragment(size_t first, size_t last) { pimpl->unmap_fragment(first, last); }
|
||||
|
||||
#if defined(_POSIX_MEMLOCK_RANGE) || defined(_WIN32)
|
||||
const bool llama_mmap::SUPPORTED = true;
|
||||
#else
|
||||
const bool llama_mmap::SUPPORTED = false;
|
||||
#endif
|
||||
|
||||
// llama_mlock
|
||||
|
||||
struct llama_mlock::impl {
|
||||
#ifdef _POSIX_MEMLOCK_RANGE
|
||||
static size_t lock_granularity() {
|
||||
return (size_t) sysconf(_SC_PAGESIZE);
|
||||
}
|
||||
|
||||
bool raw_lock(const void * addr, size_t size) const {
|
||||
if (!mlock(addr, size)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
#ifdef __APPLE__
|
||||
#define MLOCK_SUGGESTION \
|
||||
"Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
|
||||
"decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MEMLOCK (ulimit -l).\n"
|
||||
#else
|
||||
#define MLOCK_SUGGESTION \
|
||||
"Try increasing RLIMIT_MEMLOCK ('ulimit -l' as root).\n"
|
||||
#endif
|
||||
|
||||
char* errmsg = std::strerror(errno);
|
||||
bool suggest = (errno == ENOMEM);
|
||||
|
||||
struct rlimit lock_limit;
|
||||
if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) {
|
||||
suggest = false;
|
||||
}
|
||||
if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) {
|
||||
suggest = false;
|
||||
}
|
||||
|
||||
LLAMA_LOG_WARN("warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s",
|
||||
size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : "");
|
||||
return false;
|
||||
}
|
||||
|
||||
static void raw_unlock(void * addr, size_t size) {
|
||||
if (munlock(addr, size)) {
|
||||
LLAMA_LOG_WARN("warning: failed to munlock buffer: %s\n", std::strerror(errno));
|
||||
}
|
||||
}
|
||||
#elif defined(_WIN32)
|
||||
static size_t lock_granularity() {
|
||||
SYSTEM_INFO si;
|
||||
GetSystemInfo(&si);
|
||||
return (size_t) si.dwPageSize;
|
||||
}
|
||||
|
||||
bool raw_lock(void * ptr, size_t len) const {
|
||||
for (int tries = 1; ; tries++) {
|
||||
if (VirtualLock(ptr, len)) {
|
||||
return true;
|
||||
}
|
||||
if (tries == 2) {
|
||||
LLAMA_LOG_WARN("warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n",
|
||||
len, size, llama_format_win_err(GetLastError()).c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
SIZE_T min_ws_size, max_ws_size;
|
||||
if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) {
|
||||
LLAMA_LOG_WARN("warning: GetProcessWorkingSetSize failed: %s\n",
|
||||
llama_format_win_err(GetLastError()).c_str());
|
||||
return false;
|
||||
}
|
||||
size_t increment = len + 1048576;
|
||||
min_ws_size += increment;
|
||||
max_ws_size += increment;
|
||||
if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) {
|
||||
LLAMA_LOG_WARN("warning: SetProcessWorkingSetSize failed: %s\n",
|
||||
llama_format_win_err(GetLastError()).c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void raw_unlock(void * ptr, size_t len) {
|
||||
if (!VirtualUnlock(ptr, len)) {
|
||||
LLAMA_LOG_WARN("warning: failed to VirtualUnlock buffer: %s\n",
|
||||
llama_format_win_err(GetLastError()).c_str());
|
||||
}
|
||||
}
|
||||
#else
|
||||
static size_t lock_granularity() {
|
||||
return (size_t) 65536;
|
||||
}
|
||||
|
||||
bool raw_lock(const void * addr, size_t len) const {
|
||||
LLAMA_LOG_WARN("warning: mlock not supported on this system\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
static void raw_unlock(const void * addr, size_t len) {}
|
||||
#endif
|
||||
|
||||
impl() : addr(NULL), size(0), failed_already(false) {}
|
||||
|
||||
void init(void * ptr) {
|
||||
GGML_ASSERT(addr == NULL && size == 0);
|
||||
addr = ptr;
|
||||
}
|
||||
|
||||
void grow_to(size_t target_size) {
|
||||
GGML_ASSERT(addr);
|
||||
if (failed_already) {
|
||||
return;
|
||||
}
|
||||
size_t granularity = lock_granularity();
|
||||
target_size = (target_size + granularity - 1) & ~(granularity - 1);
|
||||
if (target_size > size) {
|
||||
if (raw_lock((uint8_t *) addr + size, target_size - size)) {
|
||||
size = target_size;
|
||||
} else {
|
||||
failed_already = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void * addr;
|
||||
size_t size;
|
||||
|
||||
bool failed_already;
|
||||
};
|
||||
|
||||
llama_mlock::llama_mlock() : pimpl(std::make_unique<impl>()) {}
|
||||
llama_mlock::~llama_mlock() = default;
|
||||
|
||||
void llama_mlock::init(void * ptr) { pimpl->init(ptr); }
|
||||
void llama_mlock::grow_to(size_t target_size) { pimpl->grow_to(target_size); }
|
||||
|
||||
#if defined(_POSIX_MEMLOCK_RANGE) || defined(_WIN32)
|
||||
const bool llama_mlock::SUPPORTED = true;
|
||||
#else
|
||||
const bool llama_mlock::SUPPORTED = false;
|
||||
#endif
|
||||
|
||||
size_t llama_path_max() {
|
||||
return PATH_MAX;
|
||||
}
|
67
examples/talk-llama/llama-mmap.h
Normal file
67
examples/talk-llama/llama-mmap.h
Normal file
@ -0,0 +1,67 @@
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
struct llama_file;
|
||||
struct llama_mmap;
|
||||
struct llama_mlock;
|
||||
|
||||
using llama_files = std::vector<std::unique_ptr<llama_file>>;
|
||||
using llama_mmaps = std::vector<std::unique_ptr<llama_mmap>>;
|
||||
using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;
|
||||
|
||||
struct llama_file {
|
||||
llama_file(const char * fname, const char * mode);
|
||||
~llama_file();
|
||||
|
||||
size_t tell() const;
|
||||
size_t size() const;
|
||||
|
||||
int file_id() const; // fileno overload
|
||||
|
||||
void seek(size_t offset, int whence) const;
|
||||
|
||||
void read_raw(void * ptr, size_t len) const;
|
||||
uint32_t read_u32() const;
|
||||
|
||||
void write_raw(const void * ptr, size_t len) const;
|
||||
void write_u32(uint32_t val) const;
|
||||
|
||||
private:
|
||||
struct impl;
|
||||
std::unique_ptr<impl> pimpl;
|
||||
};
|
||||
|
||||
struct llama_mmap {
|
||||
llama_mmap(const llama_mmap &) = delete;
|
||||
llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1, bool numa = false);
|
||||
~llama_mmap();
|
||||
|
||||
size_t size() const;
|
||||
void * addr() const;
|
||||
|
||||
void unmap_fragment(size_t first, size_t last);
|
||||
|
||||
static const bool SUPPORTED;
|
||||
|
||||
private:
|
||||
struct impl;
|
||||
std::unique_ptr<impl> pimpl;
|
||||
};
|
||||
|
||||
struct llama_mlock {
|
||||
llama_mlock();
|
||||
~llama_mlock();
|
||||
|
||||
void init(void * ptr);
|
||||
void grow_to(size_t target_size);
|
||||
|
||||
static const bool SUPPORTED;
|
||||
|
||||
private:
|
||||
struct impl;
|
||||
std::unique_ptr<impl> pimpl;
|
||||
};
|
||||
|
||||
size_t llama_path_max();
|
1124
examples/talk-llama/llama-model-loader.cpp
Normal file
1124
examples/talk-llama/llama-model-loader.cpp
Normal file
File diff suppressed because it is too large
Load Diff
167
examples/talk-llama/llama-model-loader.h
Normal file
167
examples/talk-llama/llama-model-loader.h
Normal file
@ -0,0 +1,167 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
#include "llama-impl.h"
|
||||
#include "llama-arch.h"
|
||||
#include "llama-mmap.h"
|
||||
|
||||
#include "ggml-cpp.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <map>
|
||||
#include <stdexcept>
|
||||
#include <unordered_map>
|
||||
|
||||
using llama_buf_map = std::unordered_map<uint32_t, ggml_backend_buffer_t>;
|
||||
|
||||
enum llama_fver {
|
||||
GGUF_FILE_VERSION_V1 = 1,
|
||||
GGUF_FILE_VERSION_V2 = 2,
|
||||
GGUF_FILE_VERSION_V3 = 3,
|
||||
};
|
||||
|
||||
const char * llama_file_version_name(llama_fver version);
|
||||
|
||||
struct llama_model_loader {
|
||||
// Holds information on a model weight
|
||||
struct llama_tensor_weight {
|
||||
uint16_t idx; // source file index
|
||||
size_t offs; // tensor data offset in the original file
|
||||
|
||||
ggml_tensor * tensor;
|
||||
|
||||
llama_tensor_weight(const llama_file * file, uint16_t idx, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) {
|
||||
const int tensor_idx = gguf_find_tensor(gguf_ctx, ggml_get_name(tensor));
|
||||
if (tensor_idx < 0) {
|
||||
throw std::runtime_error(format("tensor '%s' not found in the model", ggml_get_name(tensor)));
|
||||
}
|
||||
|
||||
offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx);
|
||||
if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size()) {
|
||||
throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", ggml_get_name(tensor)));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// custom comparator to sort weights more nicely by layer
|
||||
struct weight_name_comparer {
|
||||
bool operator()(const std::string & a, const std::string & b) const {
|
||||
int a_layer = -1;
|
||||
int b_layer = -1;
|
||||
sscanf(a.c_str(), "blk.%d.", &a_layer);
|
||||
sscanf(b.c_str(), "blk.%d.", &b_layer);
|
||||
if (a_layer != b_layer) {
|
||||
return a_layer < b_layer;
|
||||
}
|
||||
return a < b;
|
||||
}
|
||||
};
|
||||
|
||||
static const int TENSOR_NOT_REQUIRED = 1;
|
||||
static const int TENSOR_DUPLICATED = 2;
|
||||
|
||||
int n_kv = 0;
|
||||
int n_tensors = 0;
|
||||
int n_created = 0;
|
||||
|
||||
uint64_t n_elements = 0;
|
||||
size_t n_bytes = 0;
|
||||
|
||||
bool use_mmap = false;
|
||||
bool check_tensors;
|
||||
|
||||
llama_files files;
|
||||
llama_ftype ftype;
|
||||
llama_fver fver;
|
||||
|
||||
llama_mmaps mappings;
|
||||
|
||||
std::map<std::string, struct llama_tensor_weight, weight_name_comparer> weights_map;
|
||||
std::unordered_map<std::string, struct llama_model_kv_override> kv_overrides;
|
||||
|
||||
gguf_context_ptr meta;
|
||||
std::vector<ggml_context_ptr> contexts;
|
||||
|
||||
std::string arch_name;
|
||||
LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
|
||||
|
||||
size_t size_done = 0;
|
||||
size_t size_data = 0;
|
||||
std::vector<std::pair<size_t, size_t>> mmaps_used;
|
||||
|
||||
llama_model_loader(
|
||||
const std::string & fname,
|
||||
std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme
|
||||
bool use_mmap,
|
||||
bool check_tensors,
|
||||
const struct llama_model_kv_override * param_overrides_p);
|
||||
|
||||
template<typename T>
|
||||
typename std::enable_if<std::is_integral<T>::value, bool>::type
|
||||
get_arr_n(const std::string & key, T & result, bool required = true);
|
||||
|
||||
template<typename T>
|
||||
typename std::enable_if<std::is_integral<T>::value, bool>::type
|
||||
get_arr_n(enum llm_kv kid, T & result, bool required = true);
|
||||
|
||||
template<typename T>
|
||||
bool get_arr(const std::string & key, std::vector<T> & result, bool required = true);
|
||||
|
||||
template<typename T, size_t N_MAX>
|
||||
bool get_arr(const std::string & key, std::array<T, N_MAX> & result, bool required = true);
|
||||
|
||||
template<typename T>
|
||||
bool get_arr(enum llm_kv kid, T & result, bool required = true);
|
||||
|
||||
template<typename T>
|
||||
bool get_key(const std::string & key, T & result, bool required = true);
|
||||
|
||||
template<typename T>
|
||||
bool get_key(enum llm_kv kid, T & result, bool required = true);
|
||||
|
||||
template<typename T, size_t N_MAX>
|
||||
bool get_key_or_arr(const std::string & key, std::array<T, N_MAX> & result, uint32_t n, bool required = true);
|
||||
|
||||
template<typename T>
|
||||
bool get_key_or_arr(enum llm_kv kid, T & result, uint32_t n, bool required = true);
|
||||
|
||||
std::string get_arch_name() const;
|
||||
|
||||
enum llm_arch get_arch() const;
|
||||
|
||||
const llama_tensor_weight * get_weight(const char * name) const;
|
||||
|
||||
const llama_tensor_weight & require_weight(const char * name) const;
|
||||
|
||||
struct ggml_tensor * get_tensor_meta(const char * name) const;
|
||||
|
||||
struct ggml_tensor * require_tensor_meta(const std::string & name) const;
|
||||
|
||||
const struct ggml_tensor * check_tensor_dims(const std::string & name, const std::vector<int64_t> & ne, bool required) const;
|
||||
|
||||
struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list<int64_t> & ne, int flags = 0);
|
||||
|
||||
struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list<int64_t> & ne, size_t offset, bool required = true);
|
||||
|
||||
void done_getting_tensors() const;
|
||||
|
||||
void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr);
|
||||
|
||||
void get_mapping_range(size_t * first, size_t * last, void ** addr, int idx, ggml_context * ctx) const;
|
||||
|
||||
// for backwards compatibility, does not support ggml-backend
|
||||
void load_data_for(struct ggml_tensor * cur) const;
|
||||
|
||||
// Returns false if cancelled by progress_callback
|
||||
bool load_all_data(
|
||||
struct ggml_context * ctx,
|
||||
llama_buf_map & bufs,
|
||||
llama_mlocks * lmlocks,
|
||||
llama_progress_callback progress_callback,
|
||||
void * progress_callback_user_data);
|
||||
|
||||
std::string ftype_name() const;
|
||||
|
||||
void print_info() const;
|
||||
};
|
4021
examples/talk-llama/llama-model.cpp
Normal file
4021
examples/talk-llama/llama-model.cpp
Normal file
File diff suppressed because it is too large
Load Diff
370
examples/talk-llama/llama-model.h
Normal file
370
examples/talk-llama/llama-model.h
Normal file
@ -0,0 +1,370 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
#include "llama-arch.h"
|
||||
#include "llama-hparams.h"
|
||||
#include "llama-vocab.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
struct llama_model_loader;
|
||||
|
||||
// available models
|
||||
enum llm_type {
|
||||
LLM_TYPE_UNKNOWN,
|
||||
LLM_TYPE_14M,
|
||||
LLM_TYPE_17M,
|
||||
LLM_TYPE_22M,
|
||||
LLM_TYPE_33M,
|
||||
LLM_TYPE_60M,
|
||||
LLM_TYPE_70M,
|
||||
LLM_TYPE_80M,
|
||||
LLM_TYPE_109M,
|
||||
LLM_TYPE_137M,
|
||||
LLM_TYPE_160M,
|
||||
LLM_TYPE_220M,
|
||||
LLM_TYPE_250M,
|
||||
LLM_TYPE_270M,
|
||||
LLM_TYPE_335M,
|
||||
LLM_TYPE_410M,
|
||||
LLM_TYPE_450M,
|
||||
LLM_TYPE_770M,
|
||||
LLM_TYPE_780M,
|
||||
LLM_TYPE_0_5B,
|
||||
LLM_TYPE_1B,
|
||||
LLM_TYPE_1_3B,
|
||||
LLM_TYPE_1_4B,
|
||||
LLM_TYPE_1_5B,
|
||||
LLM_TYPE_1_6B,
|
||||
LLM_TYPE_2B,
|
||||
LLM_TYPE_2_8B,
|
||||
LLM_TYPE_3B,
|
||||
LLM_TYPE_4B,
|
||||
LLM_TYPE_6B,
|
||||
LLM_TYPE_6_9B,
|
||||
LLM_TYPE_7B,
|
||||
LLM_TYPE_8B,
|
||||
LLM_TYPE_9B,
|
||||
LLM_TYPE_11B,
|
||||
LLM_TYPE_12B,
|
||||
LLM_TYPE_13B,
|
||||
LLM_TYPE_14B,
|
||||
LLM_TYPE_15B,
|
||||
LLM_TYPE_16B,
|
||||
LLM_TYPE_20B,
|
||||
LLM_TYPE_30B,
|
||||
LLM_TYPE_32B,
|
||||
LLM_TYPE_34B,
|
||||
LLM_TYPE_35B,
|
||||
LLM_TYPE_40B,
|
||||
LLM_TYPE_65B,
|
||||
LLM_TYPE_70B,
|
||||
LLM_TYPE_236B,
|
||||
LLM_TYPE_314B,
|
||||
LLM_TYPE_671B,
|
||||
LLM_TYPE_SMALL,
|
||||
LLM_TYPE_MEDIUM,
|
||||
LLM_TYPE_LARGE,
|
||||
LLM_TYPE_XL,
|
||||
LLM_TYPE_A1_7B,
|
||||
LLM_TYPE_A2_7B,
|
||||
LLM_TYPE_8x7B,
|
||||
LLM_TYPE_8x22B,
|
||||
LLM_TYPE_16x12B,
|
||||
LLM_TYPE_16x3_8B,
|
||||
LLM_TYPE_10B_128x3_66B,
|
||||
LLM_TYPE_57B_A14B,
|
||||
LLM_TYPE_27B,
|
||||
};
|
||||
|
||||
struct llama_layer_posnet {
|
||||
// resnet
|
||||
struct ggml_tensor * norm1 = nullptr;
|
||||
struct ggml_tensor * norm1_b = nullptr;
|
||||
|
||||
struct ggml_tensor * conv1 = nullptr;
|
||||
struct ggml_tensor * conv1_b = nullptr;
|
||||
|
||||
struct ggml_tensor * norm2 = nullptr;
|
||||
struct ggml_tensor * norm2_b = nullptr;
|
||||
|
||||
struct ggml_tensor * conv2 = nullptr;
|
||||
struct ggml_tensor * conv2_b = nullptr;
|
||||
|
||||
// attention
|
||||
struct ggml_tensor * attn_norm = nullptr;
|
||||
struct ggml_tensor * attn_norm_b = nullptr;
|
||||
|
||||
struct ggml_tensor * attn_q = nullptr;
|
||||
struct ggml_tensor * attn_q_b = nullptr;
|
||||
|
||||
struct ggml_tensor * attn_k = nullptr;
|
||||
struct ggml_tensor * attn_k_b = nullptr;
|
||||
|
||||
struct ggml_tensor * attn_v = nullptr;
|
||||
struct ggml_tensor * attn_v_b = nullptr;
|
||||
|
||||
struct ggml_tensor * attn_o = nullptr;
|
||||
struct ggml_tensor * attn_o_b = nullptr;
|
||||
|
||||
// normalize
|
||||
struct ggml_tensor * norm = nullptr;
|
||||
struct ggml_tensor * norm_b = nullptr;
|
||||
};
|
||||
|
||||
struct llama_layer_convnext {
|
||||
struct ggml_tensor * dw = nullptr;
|
||||
struct ggml_tensor * dw_b = nullptr;
|
||||
|
||||
struct ggml_tensor * norm = nullptr;
|
||||
struct ggml_tensor * norm_b = nullptr;
|
||||
|
||||
struct ggml_tensor * pw1 = nullptr;
|
||||
struct ggml_tensor * pw1_b = nullptr;
|
||||
|
||||
struct ggml_tensor * pw2 = nullptr;
|
||||
struct ggml_tensor * pw2_b = nullptr;
|
||||
|
||||
struct ggml_tensor * gamma = nullptr;
|
||||
};
|
||||
|
||||
struct llama_layer {
|
||||
// normalization
|
||||
struct ggml_tensor * attn_norm = nullptr;
|
||||
struct ggml_tensor * attn_norm_b = nullptr;
|
||||
struct ggml_tensor * attn_norm_2 = nullptr;
|
||||
struct ggml_tensor * attn_norm_2_b = nullptr;
|
||||
struct ggml_tensor * attn_q_norm = nullptr;
|
||||
struct ggml_tensor * attn_q_norm_b = nullptr;
|
||||
struct ggml_tensor * attn_k_norm = nullptr;
|
||||
struct ggml_tensor * attn_k_norm_b = nullptr;
|
||||
struct ggml_tensor * attn_out_norm = nullptr;
|
||||
struct ggml_tensor * attn_out_norm_b = nullptr;
|
||||
struct ggml_tensor * attn_q_a_norm = nullptr;
|
||||
struct ggml_tensor * attn_kv_a_norm = nullptr;
|
||||
struct ggml_tensor * attn_sub_norm = nullptr;
|
||||
struct ggml_tensor * attn_post_norm = nullptr;
|
||||
struct ggml_tensor * ffn_sub_norm = nullptr;
|
||||
struct ggml_tensor * attn_norm_cross = nullptr;
|
||||
struct ggml_tensor * attn_norm_enc = nullptr;
|
||||
|
||||
// attention
|
||||
struct ggml_tensor * wq = nullptr;
|
||||
struct ggml_tensor * wk = nullptr;
|
||||
struct ggml_tensor * wv = nullptr;
|
||||
struct ggml_tensor * wo = nullptr;
|
||||
struct ggml_tensor * wqkv = nullptr;
|
||||
struct ggml_tensor * wq_a = nullptr;
|
||||
struct ggml_tensor * wq_b = nullptr;
|
||||
struct ggml_tensor * wkv_a_mqa = nullptr;
|
||||
struct ggml_tensor * wkv_b = nullptr;
|
||||
struct ggml_tensor * wq_cross = nullptr;
|
||||
struct ggml_tensor * wk_cross = nullptr;
|
||||
struct ggml_tensor * wv_cross = nullptr;
|
||||
struct ggml_tensor * wo_cross = nullptr;
|
||||
struct ggml_tensor * wq_enc = nullptr;
|
||||
struct ggml_tensor * wk_enc = nullptr;
|
||||
struct ggml_tensor * wv_enc = nullptr;
|
||||
struct ggml_tensor * wo_enc = nullptr;
|
||||
|
||||
// attention bias
|
||||
struct ggml_tensor * bq = nullptr;
|
||||
struct ggml_tensor * bk = nullptr;
|
||||
struct ggml_tensor * bv = nullptr;
|
||||
struct ggml_tensor * bo = nullptr;
|
||||
struct ggml_tensor * bqkv = nullptr;
|
||||
|
||||
// relative position bias
|
||||
struct ggml_tensor * attn_rel_b = nullptr;
|
||||
struct ggml_tensor * attn_rel_b_enc = nullptr;
|
||||
struct ggml_tensor * attn_rel_b_cross = nullptr;
|
||||
|
||||
// normalization
|
||||
struct ggml_tensor * ffn_norm = nullptr;
|
||||
struct ggml_tensor * ffn_norm_b = nullptr;
|
||||
struct ggml_tensor * ffn_post_norm = nullptr;
|
||||
struct ggml_tensor * layer_out_norm = nullptr;
|
||||
struct ggml_tensor * layer_out_norm_b = nullptr;
|
||||
struct ggml_tensor * ffn_norm_exps = nullptr;
|
||||
struct ggml_tensor * ffn_norm_enc = nullptr;
|
||||
|
||||
// ff
|
||||
struct ggml_tensor * ffn_gate = nullptr; // w1
|
||||
struct ggml_tensor * ffn_down = nullptr; // w2
|
||||
struct ggml_tensor * ffn_up = nullptr; // w3
|
||||
struct ggml_tensor * ffn_gate_enc = nullptr;
|
||||
struct ggml_tensor * ffn_down_enc = nullptr;
|
||||
struct ggml_tensor * ffn_up_enc = nullptr;
|
||||
|
||||
// ff MoE
|
||||
struct ggml_tensor * ffn_gate_inp = nullptr;
|
||||
struct ggml_tensor * ffn_gate_exps = nullptr;
|
||||
struct ggml_tensor * ffn_down_exps = nullptr;
|
||||
struct ggml_tensor * ffn_up_exps = nullptr;
|
||||
|
||||
// ff shared expert (shexp)
|
||||
struct ggml_tensor * ffn_gate_inp_shexp = nullptr;
|
||||
struct ggml_tensor * ffn_gate_shexp = nullptr;
|
||||
struct ggml_tensor * ffn_down_shexp = nullptr;
|
||||
struct ggml_tensor * ffn_up_shexp = nullptr;
|
||||
|
||||
// ff bias
|
||||
struct ggml_tensor * ffn_gate_b = nullptr;
|
||||
struct ggml_tensor * ffn_down_b = nullptr; // b2
|
||||
struct ggml_tensor * ffn_up_b = nullptr; // b3
|
||||
struct ggml_tensor * ffn_act = nullptr;
|
||||
struct ggml_tensor * ffn_exp_probs_b = nullptr;
|
||||
|
||||
// mamba proj
|
||||
struct ggml_tensor * ssm_in = nullptr;
|
||||
struct ggml_tensor * ssm_x = nullptr;
|
||||
struct ggml_tensor * ssm_dt = nullptr;
|
||||
struct ggml_tensor * ssm_out = nullptr;
|
||||
|
||||
// mamba
|
||||
struct ggml_tensor * ssm_conv1d = nullptr;
|
||||
struct ggml_tensor * ssm_a = nullptr;
|
||||
struct ggml_tensor * ssm_d = nullptr;
|
||||
|
||||
// mamba bias
|
||||
struct ggml_tensor * ssm_conv1d_b = nullptr;
|
||||
struct ggml_tensor * ssm_dt_b = nullptr;
|
||||
|
||||
// rwkv
|
||||
struct ggml_tensor * time_mix_w1 = nullptr;
|
||||
struct ggml_tensor * time_mix_w2 = nullptr;
|
||||
struct ggml_tensor * time_mix_lerp_x = nullptr;
|
||||
struct ggml_tensor * time_mix_lerp_w = nullptr;
|
||||
struct ggml_tensor * time_mix_lerp_k = nullptr;
|
||||
struct ggml_tensor * time_mix_lerp_v = nullptr;
|
||||
struct ggml_tensor * time_mix_lerp_r = nullptr;
|
||||
struct ggml_tensor * time_mix_lerp_g = nullptr;
|
||||
struct ggml_tensor * time_mix_lerp_fused = nullptr;
|
||||
|
||||
struct ggml_tensor * time_mix_first = nullptr;
|
||||
struct ggml_tensor * time_mix_decay = nullptr;
|
||||
struct ggml_tensor * time_mix_decay_w1 = nullptr;
|
||||
struct ggml_tensor * time_mix_decay_w2 = nullptr;
|
||||
struct ggml_tensor * time_mix_key = nullptr;
|
||||
struct ggml_tensor * time_mix_key_b = nullptr;
|
||||
struct ggml_tensor * time_mix_value = nullptr;
|
||||
struct ggml_tensor * time_mix_value_b = nullptr;
|
||||
struct ggml_tensor * time_mix_receptance = nullptr;
|
||||
struct ggml_tensor * time_mix_receptance_b = nullptr;
|
||||
struct ggml_tensor * time_mix_gate = nullptr;
|
||||
|
||||
struct ggml_tensor * time_mix_ln = nullptr;
|
||||
struct ggml_tensor * time_mix_ln_b = nullptr;
|
||||
struct ggml_tensor * time_mix_output = nullptr;
|
||||
|
||||
struct ggml_tensor * channel_mix_lerp_k = nullptr;
|
||||
struct ggml_tensor * channel_mix_lerp_r = nullptr;
|
||||
|
||||
struct ggml_tensor * channel_mix_key = nullptr;
|
||||
struct ggml_tensor * channel_mix_receptance = nullptr;
|
||||
struct ggml_tensor * channel_mix_value = nullptr;
|
||||
|
||||
// long rope factors
|
||||
struct ggml_tensor * rope_long = nullptr;
|
||||
struct ggml_tensor * rope_short = nullptr;
|
||||
struct ggml_tensor * rope_freqs = nullptr;
|
||||
|
||||
// bitnet scale
|
||||
struct ggml_tensor * wq_scale = nullptr;
|
||||
struct ggml_tensor * wk_scale = nullptr;
|
||||
struct ggml_tensor * wv_scale = nullptr;
|
||||
struct ggml_tensor * wo_scale = nullptr;
|
||||
struct ggml_tensor * ffn_gate_scale = nullptr;
|
||||
struct ggml_tensor * ffn_up_scale = nullptr;
|
||||
struct ggml_tensor * ffn_down_scale = nullptr;
|
||||
|
||||
struct llama_layer_posnet posnet;
|
||||
|
||||
struct llama_layer_convnext convnext;
|
||||
};
|
||||
|
||||
struct llama_model {
|
||||
llm_type type = LLM_TYPE_UNKNOWN;
|
||||
llm_arch arch = LLM_ARCH_UNKNOWN;
|
||||
|
||||
std::string name = "n/a";
|
||||
|
||||
llama_hparams hparams = {};
|
||||
llama_vocab vocab;
|
||||
|
||||
struct ggml_tensor * tok_embd = nullptr;
|
||||
struct ggml_tensor * type_embd = nullptr;
|
||||
struct ggml_tensor * pos_embd = nullptr;
|
||||
struct ggml_tensor * tok_norm = nullptr;
|
||||
struct ggml_tensor * tok_norm_b = nullptr;
|
||||
|
||||
struct ggml_tensor * output_norm = nullptr;
|
||||
struct ggml_tensor * output_norm_b = nullptr;
|
||||
struct ggml_tensor * output = nullptr;
|
||||
struct ggml_tensor * output_b = nullptr;
|
||||
struct ggml_tensor * output_norm_enc = nullptr;
|
||||
|
||||
// classifier
|
||||
struct ggml_tensor * cls = nullptr;
|
||||
struct ggml_tensor * cls_b = nullptr;
|
||||
struct ggml_tensor * cls_out = nullptr;
|
||||
struct ggml_tensor * cls_out_b = nullptr;
|
||||
|
||||
struct ggml_tensor * conv1d = nullptr;
|
||||
struct ggml_tensor * conv1d_b = nullptr;
|
||||
|
||||
std::vector<llama_layer> layers;
|
||||
|
||||
llama_model_params params;
|
||||
|
||||
// gguf metadata
|
||||
std::unordered_map<std::string, std::string> gguf_kv;
|
||||
|
||||
// list of devices used in this model
|
||||
std::vector<ggml_backend_dev_t> devices;
|
||||
|
||||
// for quantize-stats only
|
||||
std::vector<std::pair<std::string, struct ggml_tensor *>> tensors_by_name;
|
||||
|
||||
int64_t t_load_us = 0;
|
||||
int64_t t_start_us = 0;
|
||||
|
||||
explicit llama_model(const struct llama_model_params & params);
|
||||
~llama_model();
|
||||
|
||||
void load_stats (llama_model_loader & ml);
|
||||
void load_arch (llama_model_loader & ml);
|
||||
void load_hparams(llama_model_loader & ml);
|
||||
void load_vocab (llama_model_loader & ml);
|
||||
bool load_tensors(llama_model_loader & ml); // returns false if cancelled by progress_callback
|
||||
|
||||
std::string arch_name() const;
|
||||
std::string type_name() const;
|
||||
|
||||
std::string desc() const;
|
||||
|
||||
size_t size() const;
|
||||
size_t max_nodes() const;
|
||||
size_t n_devices() const;
|
||||
|
||||
// total number of parameters in the model
|
||||
uint64_t n_elements() const;
|
||||
|
||||
void print_info() const;
|
||||
|
||||
ggml_backend_dev_t dev_layer(int il) const;
|
||||
ggml_backend_dev_t dev_output() const;
|
||||
|
||||
ggml_backend_buffer_type_t select_buft(int il) const;
|
||||
|
||||
const struct ggml_tensor * get_tensor(const char * name) const;
|
||||
|
||||
private:
|
||||
struct impl;
|
||||
std::unique_ptr<impl> pimpl;
|
||||
};
|
||||
|
||||
const char * llm_type_name(llm_type type);
|
934
examples/talk-llama/llama-quant.cpp
Normal file
934
examples/talk-llama/llama-quant.cpp
Normal file
@ -0,0 +1,934 @@
|
||||
#include "llama-quant.h"
|
||||
|
||||
#include "llama-impl.h"
|
||||
#include "llama-model.h"
|
||||
#include "llama-model-loader.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <cinttypes>
|
||||
#include <fstream>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
|
||||
static void zeros(std::ofstream & file, size_t n) {
|
||||
char zero = 0;
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
file.write(&zero, 1);
|
||||
}
|
||||
}
|
||||
|
||||
struct quantize_state_impl {
|
||||
const llama_model & model;
|
||||
const llama_model_quantize_params * params;
|
||||
|
||||
int n_attention_wv = 0;
|
||||
int n_ffn_down = 0;
|
||||
int n_ffn_gate = 0;
|
||||
int n_ffn_up = 0;
|
||||
int i_attention_wv = 0;
|
||||
int i_ffn_down = 0;
|
||||
int i_ffn_gate = 0;
|
||||
int i_ffn_up = 0;
|
||||
|
||||
int n_k_quantized = 0;
|
||||
int n_fallback = 0;
|
||||
|
||||
bool has_imatrix = false;
|
||||
|
||||
// used to figure out if a model shares tok_embd with the output weight
|
||||
bool has_output = false;
|
||||
|
||||
quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params)
|
||||
: model(model)
|
||||
, params(params)
|
||||
{}
|
||||
};
|
||||
|
||||
static void llama_tensor_dequantize_impl(
|
||||
struct ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
|
||||
const size_t nelements, const int nthread
|
||||
) {
|
||||
if (output.size() < nelements) {
|
||||
output.resize(nelements);
|
||||
}
|
||||
float * f32_output = (float *) output.data();
|
||||
|
||||
const ggml_type_traits * qtype = ggml_get_type_traits(tensor->type);
|
||||
if (ggml_is_quantized(tensor->type)) {
|
||||
if (qtype->to_float == NULL) {
|
||||
throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor->type)));
|
||||
}
|
||||
} else if (tensor->type != GGML_TYPE_F16 &&
|
||||
tensor->type != GGML_TYPE_BF16) {
|
||||
throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor->type)));
|
||||
}
|
||||
|
||||
if (nthread < 2) {
|
||||
if (tensor->type == GGML_TYPE_F16) {
|
||||
ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor->data, f32_output, nelements);
|
||||
} else if (tensor->type == GGML_TYPE_BF16) {
|
||||
ggml_bf16_to_fp32_row((ggml_bf16_t *)tensor->data, f32_output, nelements);
|
||||
} else if (ggml_is_quantized(tensor->type)) {
|
||||
qtype->to_float(tensor->data, f32_output, nelements);
|
||||
} else {
|
||||
GGML_ABORT("fatal error"); // unreachable
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
size_t block_size;
|
||||
if (tensor->type == GGML_TYPE_F16 ||
|
||||
tensor->type == GGML_TYPE_BF16) {
|
||||
block_size = 1;
|
||||
} else {
|
||||
block_size = (size_t)ggml_blck_size(tensor->type);
|
||||
}
|
||||
|
||||
size_t block_size_bytes = ggml_type_size(tensor->type);
|
||||
|
||||
GGML_ASSERT(nelements % block_size == 0);
|
||||
size_t nblocks = nelements / block_size;
|
||||
size_t blocks_per_thread = nblocks / nthread;
|
||||
size_t spare_blocks = nblocks - (blocks_per_thread * nthread); // if blocks aren't divisible by thread count
|
||||
|
||||
size_t in_buff_offs = 0;
|
||||
size_t out_buff_offs = 0;
|
||||
|
||||
for (int tnum = 0; tnum < nthread; tnum++) {
|
||||
size_t thr_blocks = blocks_per_thread + (tnum == nthread - 1 ? spare_blocks : 0); // num blocks for this thread
|
||||
size_t thr_elems = thr_blocks * block_size; // number of elements for this thread
|
||||
size_t thr_block_bytes = thr_blocks * block_size_bytes; // number of input bytes for this thread
|
||||
|
||||
auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) {
|
||||
if (typ == GGML_TYPE_F16) {
|
||||
ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels);
|
||||
} else if (typ == GGML_TYPE_BF16) {
|
||||
ggml_bf16_to_fp32_row((ggml_bf16_t *)inbuf, outbuf, nels);
|
||||
} else {
|
||||
qtype->to_float(inbuf, outbuf, nels);
|
||||
}
|
||||
};
|
||||
workers.emplace_back(compute, tensor->type, (uint8_t *) tensor->data + in_buff_offs, f32_output + out_buff_offs, thr_elems);
|
||||
in_buff_offs += thr_block_bytes;
|
||||
out_buff_offs += thr_elems;
|
||||
}
|
||||
for (auto & w : workers) { w.join(); }
|
||||
workers.clear();
|
||||
}
|
||||
|
||||
static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
|
||||
const std::string name = ggml_get_name(tensor);
|
||||
|
||||
// TODO: avoid hardcoded tensor names - use the TN_* constants
|
||||
const llm_arch arch = qs.model.arch;
|
||||
const auto tn = LLM_TN(arch);
|
||||
|
||||
auto use_more_bits = [](int i_layer, int n_layers) -> bool {
|
||||
return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2;
|
||||
};
|
||||
const int n_expert = std::max(1, (int)qs.model.hparams.n_expert);
|
||||
auto layer_info = [n_expert] (int i_layer, int n_layer, const char * name) {
|
||||
if (n_expert > 1) {
|
||||
// Believe it or not, "experts" in the FFN of Mixtral-8x7B are not consecutive, but occasionally randomly
|
||||
// sprinkled in the model. Hence, simply dividing i_ffn_down by n_expert does not work
|
||||
// for getting the current layer as I initially thought, and we need to resort to parsing the
|
||||
// tensor name.
|
||||
if (sscanf(name, "blk.%d.", &i_layer) != 1) {
|
||||
throw std::runtime_error(format("Failed to determine layer for tensor %s", name));
|
||||
}
|
||||
if (i_layer < 0 || i_layer >= n_layer) {
|
||||
throw std::runtime_error(format("Bad layer %d for tensor %s. Must be in [0, %d)", i_layer, name, n_layer));
|
||||
}
|
||||
}
|
||||
return std::make_pair(i_layer, n_layer);
|
||||
};
|
||||
|
||||
// for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings
|
||||
// with the quantization of the output tensor
|
||||
if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) {
|
||||
if (qs.params->output_tensor_type < GGML_TYPE_COUNT) {
|
||||
new_type = qs.params->output_tensor_type;
|
||||
} else {
|
||||
const int64_t nx = tensor->ne[0];
|
||||
const int64_t qk_k = ggml_blck_size(new_type);
|
||||
|
||||
if (arch == LLM_ARCH_FALCON || nx % qk_k != 0) {
|
||||
new_type = GGML_TYPE_Q8_0;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
|
||||
ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ||
|
||||
ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
|
||||
new_type = GGML_TYPE_Q5_K;
|
||||
}
|
||||
else if (new_type != GGML_TYPE_Q8_0) {
|
||||
new_type = GGML_TYPE_Q6_K;
|
||||
}
|
||||
}
|
||||
} else if (name == "token_embd.weight") {
|
||||
if (qs.params->token_embedding_type < GGML_TYPE_COUNT) {
|
||||
new_type = qs.params->token_embedding_type;
|
||||
} else {
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS ||
|
||||
ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
|
||||
new_type = GGML_TYPE_Q2_K;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) {
|
||||
new_type = GGML_TYPE_IQ3_S;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
|
||||
new_type = GGML_TYPE_IQ3_S;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_TQ1_0 || ftype == LLAMA_FTYPE_MOSTLY_TQ2_0) {
|
||||
new_type = GGML_TYPE_Q4_K;
|
||||
}
|
||||
}
|
||||
} else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
|
||||
ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
|
||||
if (name.find("attn_v.weight") != std::string::npos) {
|
||||
if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K;
|
||||
else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
|
||||
++qs.i_attention_wv;
|
||||
}
|
||||
else if (qs.model.hparams.n_expert == 8 && name.find("attn_k.weight") != std::string::npos) {
|
||||
new_type = GGML_TYPE_Q4_K;
|
||||
}
|
||||
else if (name.find("ffn_down") != std::string::npos) {
|
||||
if (qs.i_ffn_down < qs.n_ffn_down/8) {
|
||||
new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
|
||||
}
|
||||
++qs.i_ffn_down;
|
||||
}
|
||||
else if (name.find("attn_output.weight") != std::string::npos) {
|
||||
if (qs.model.hparams.n_expert == 8) {
|
||||
new_type = GGML_TYPE_Q5_K;
|
||||
} else {
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) new_type = GGML_TYPE_IQ2_XXS;
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) new_type = GGML_TYPE_IQ3_S;
|
||||
}
|
||||
}
|
||||
} else if (name.find("attn_v.weight") != std::string::npos) {
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) {
|
||||
new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && qs.model.hparams.n_gqa() >= 4) {
|
||||
new_type = GGML_TYPE_Q4_K;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
|
||||
new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : !qs.has_imatrix ? GGML_TYPE_IQ3_S : GGML_TYPE_IQ3_XXS;
|
||||
}
|
||||
else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S) && qs.model.hparams.n_gqa() >= 4) {
|
||||
new_type = GGML_TYPE_Q4_K;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
|
||||
new_type = GGML_TYPE_Q4_K;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
|
||||
new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
|
||||
else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && qs.model.hparams.n_gqa() >= 4) {
|
||||
new_type = GGML_TYPE_Q5_K;
|
||||
}
|
||||
else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
|
||||
use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) new_type = GGML_TYPE_Q6_K;
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && qs.i_attention_wv < 4) new_type = GGML_TYPE_Q5_K;
|
||||
if (qs.model.type == LLM_TYPE_70B) {
|
||||
// In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is
|
||||
// 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with
|
||||
// nearly negligible increase in model size by quantizing this tensor with more bits:
|
||||
if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K;
|
||||
}
|
||||
if (qs.model.hparams.n_expert == 8) {
|
||||
// for the 8-expert model, bumping this to Q8_0 trades just ~128MB
|
||||
// TODO: explore better strategies
|
||||
new_type = GGML_TYPE_Q8_0;
|
||||
}
|
||||
++qs.i_attention_wv;
|
||||
} else if (name.find("attn_k.weight") != std::string::npos) {
|
||||
if (qs.model.hparams.n_expert == 8) {
|
||||
// for the 8-expert model, bumping this to Q8_0 trades just ~128MB
|
||||
// TODO: explore better strategies
|
||||
new_type = GGML_TYPE_Q8_0;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
|
||||
new_type = GGML_TYPE_IQ3_XXS;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
|
||||
new_type = GGML_TYPE_IQ2_S;
|
||||
}
|
||||
} else if (name.find("attn_q.weight") != std::string::npos) {
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
|
||||
new_type = GGML_TYPE_IQ3_XXS;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
|
||||
new_type = GGML_TYPE_IQ2_S;
|
||||
}
|
||||
} else if (name.find("ffn_down") != std::string::npos) {
|
||||
auto info = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str());
|
||||
int i_layer = info.first, n_layer = info.second;
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S) {
|
||||
if (i_layer < n_layer/8) new_type = GGML_TYPE_Q4_K;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS && !qs.has_imatrix) {
|
||||
new_type = i_layer < n_layer/8 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
|
||||
new_type = i_layer < n_layer/16 ? GGML_TYPE_Q5_K
|
||||
: arch != LLM_ARCH_FALCON || use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q4_K
|
||||
: GGML_TYPE_Q3_K;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M && (i_layer < n_layer/8 ||
|
||||
(qs.model.hparams.n_expert == 8 && use_more_bits(i_layer, n_layer)))) {
|
||||
new_type = GGML_TYPE_Q4_K;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) {
|
||||
new_type = arch == LLM_ARCH_FALCON ? GGML_TYPE_Q4_K : GGML_TYPE_Q5_K;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) {
|
||||
if (arch == LLM_ARCH_FALCON) {
|
||||
new_type = i_layer < n_layer/16 ? GGML_TYPE_Q6_K :
|
||||
use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
|
||||
} else {
|
||||
if (use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
|
||||
}
|
||||
}
|
||||
else if (i_layer < n_layer/8 && (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && !qs.has_imatrix) {
|
||||
new_type = GGML_TYPE_Q5_K;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && arch != LLM_ARCH_FALCON && i_layer < n_layer/8) {
|
||||
new_type = GGML_TYPE_Q5_K;
|
||||
}
|
||||
else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || ftype == LLAMA_FTYPE_MOSTLY_Q5_0)
|
||||
&& qs.has_imatrix && i_layer < n_layer/8) {
|
||||
// Guard against craziness in the first few ffn_down layers that can happen even with imatrix for Q4_0/Q5_0.
|
||||
// We only do it when an imatrix is provided because a) we want to make sure that one can always get the
|
||||
// same quantization as before imatrix stuff, and b) Q4_1/Q5_1 do go crazy on ffn_down without an imatrix.
|
||||
new_type = ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ? GGML_TYPE_Q4_1 : GGML_TYPE_Q5_1;
|
||||
}
|
||||
++qs.i_ffn_down;
|
||||
} else if (name.find("attn_output.weight") != std::string::npos) {
|
||||
if (arch != LLM_ARCH_FALCON) {
|
||||
if (qs.model.hparams.n_expert == 8) {
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
|
||||
ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL ||
|
||||
ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S ||
|
||||
ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) {
|
||||
new_type = GGML_TYPE_Q5_K;
|
||||
}
|
||||
} else {
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K ) new_type = GGML_TYPE_Q3_K;
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) new_type = GGML_TYPE_IQ3_S;
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M ) new_type = GGML_TYPE_Q4_K;
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L ) new_type = GGML_TYPE_Q5_K;
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M ) new_type = GGML_TYPE_Q4_K;
|
||||
}
|
||||
} else {
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K;
|
||||
}
|
||||
}
|
||||
else if (name.find("attn_qkv.weight") != std::string::npos) {
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
|
||||
new_type = GGML_TYPE_Q4_K;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K;
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
|
||||
}
|
||||
else if (name.find("ffn_gate") != std::string::npos) {
|
||||
auto info = layer_info(qs.i_ffn_gate, qs.n_ffn_gate, name.c_str());
|
||||
int i_layer = info.first, n_layer = info.second;
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
|
||||
new_type = GGML_TYPE_IQ3_XXS;
|
||||
}
|
||||
++qs.i_ffn_gate;
|
||||
}
|
||||
else if (name.find("ffn_up") != std::string::npos) {
|
||||
auto info = layer_info(qs.i_ffn_up, qs.n_ffn_up, name.c_str());
|
||||
int i_layer = info.first, n_layer = info.second;
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
|
||||
new_type = GGML_TYPE_IQ3_XXS;
|
||||
}
|
||||
++qs.i_ffn_up;
|
||||
}
|
||||
|
||||
// if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
|
||||
//}
|
||||
// IK: let's remove this, else Q2_K is almost the same as Q3_K_S
|
||||
//else if (name.find("ffn_gate") != std::string::npos || name.find("ffn_up") != std::string::npos) {
|
||||
// if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
|
||||
//}
|
||||
// This can be used to reduce the size of the Q5_K_S model.
|
||||
// The associated PPL increase is fully in line with the size reduction
|
||||
//else {
|
||||
// if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K;
|
||||
//}
|
||||
bool convert_incompatible_tensor = false;
|
||||
{
|
||||
const int64_t nx = tensor->ne[0];
|
||||
const int64_t ny = tensor->ne[1];
|
||||
const int64_t qk_k = ggml_blck_size(new_type);
|
||||
|
||||
if (nx % qk_k != 0) {
|
||||
LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type));
|
||||
convert_incompatible_tensor = true;
|
||||
} else {
|
||||
++qs.n_k_quantized;
|
||||
}
|
||||
}
|
||||
|
||||
if (convert_incompatible_tensor) {
|
||||
switch (new_type) {
|
||||
case GGML_TYPE_TQ1_0:
|
||||
case GGML_TYPE_TQ2_0: new_type = GGML_TYPE_Q4_0; break; // TODO: use a symmetric type instead
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break;
|
||||
case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break;
|
||||
case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break;
|
||||
case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break;
|
||||
default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
|
||||
}
|
||||
if (tensor->ne[0] % ggml_blck_size(new_type) != 0) {
|
||||
new_type = GGML_TYPE_F16;
|
||||
}
|
||||
LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type));
|
||||
++qs.n_fallback;
|
||||
}
|
||||
|
||||
return new_type;
|
||||
}
|
||||
|
||||
static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector<std::thread> & workers, const int nthread) {
|
||||
if (nthread < 2) {
|
||||
// single-thread
|
||||
size_t new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix);
|
||||
if (!ggml_validate_row_data(new_type, new_data, new_size)) {
|
||||
throw std::runtime_error("quantized data validation failed");
|
||||
}
|
||||
return new_size;
|
||||
}
|
||||
|
||||
std::mutex mutex;
|
||||
int64_t counter = 0;
|
||||
size_t new_size = 0;
|
||||
bool valid = true;
|
||||
auto compute = [&mutex, &counter, &new_size, &valid, new_type, f32_data, new_data, chunk_size,
|
||||
nrows, n_per_row, imatrix]() {
|
||||
const int64_t nrows_per_chunk = chunk_size / n_per_row;
|
||||
size_t local_size = 0;
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
int64_t first_row = counter; counter += nrows_per_chunk;
|
||||
if (first_row >= nrows) {
|
||||
if (local_size > 0) {
|
||||
new_size += local_size;
|
||||
}
|
||||
break;
|
||||
}
|
||||
lock.unlock();
|
||||
const int64_t this_nrow = std::min(nrows - first_row, nrows_per_chunk);
|
||||
size_t this_size = ggml_quantize_chunk(new_type, f32_data, new_data, first_row * n_per_row, this_nrow, n_per_row, imatrix);
|
||||
local_size += this_size;
|
||||
|
||||
// validate the quantized data
|
||||
const size_t row_size = ggml_row_size(new_type, n_per_row);
|
||||
void * this_data = (char *) new_data + first_row * row_size;
|
||||
if (!ggml_validate_row_data(new_type, this_data, this_size)) {
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
valid = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
for (int it = 0; it < nthread - 1; ++it) {
|
||||
workers.emplace_back(compute);
|
||||
}
|
||||
compute();
|
||||
for (auto & w : workers) { w.join(); }
|
||||
workers.clear();
|
||||
if (!valid) {
|
||||
throw std::runtime_error("quantized data validation failed");
|
||||
}
|
||||
return new_size;
|
||||
}
|
||||
|
||||
static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
|
||||
ggml_type default_type;
|
||||
llama_ftype ftype = params->ftype;
|
||||
|
||||
switch (params->ftype) {
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break;
|
||||
case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break;
|
||||
case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
|
||||
case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break;
|
||||
|
||||
// K-quants
|
||||
case LLAMA_FTYPE_MOSTLY_Q2_K_S:
|
||||
case LLAMA_FTYPE_MOSTLY_Q2_K: default_type = GGML_TYPE_Q2_K; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_XS: default_type = GGML_TYPE_IQ3_S; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q3_K_S:
|
||||
case LLAMA_FTYPE_MOSTLY_Q3_K_M:
|
||||
case LLAMA_FTYPE_MOSTLY_Q3_K_L: default_type = GGML_TYPE_Q3_K; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_K_S:
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_K_M: default_type = GGML_TYPE_Q4_K; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q5_K_S:
|
||||
case LLAMA_FTYPE_MOSTLY_Q5_K_M: default_type = GGML_TYPE_Q5_K; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q6_K: default_type = GGML_TYPE_Q6_K; break;
|
||||
case LLAMA_FTYPE_MOSTLY_TQ1_0: default_type = GGML_TYPE_TQ1_0; break;
|
||||
case LLAMA_FTYPE_MOSTLY_TQ2_0: default_type = GGML_TYPE_TQ2_0; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_XS: default_type = GGML_TYPE_IQ2_XS; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_S: default_type = GGML_TYPE_IQ2_XS; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ2_M: default_type = GGML_TYPE_IQ2_S; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_XXS: default_type = GGML_TYPE_IQ3_XXS; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ1_S: default_type = GGML_TYPE_IQ1_S; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ1_M: default_type = GGML_TYPE_IQ1_M; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break;
|
||||
case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break;
|
||||
|
||||
default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
|
||||
}
|
||||
|
||||
int nthread = params->nthread;
|
||||
|
||||
if (nthread <= 0) {
|
||||
nthread = std::thread::hardware_concurrency();
|
||||
}
|
||||
|
||||
// mmap consistently increases speed Linux, and also increases speed on Windows with
|
||||
// hot cache. It may cause a slowdown on macOS, possibly related to free memory.
|
||||
#if defined(__linux__) || defined(_WIN32)
|
||||
constexpr bool use_mmap = true;
|
||||
#else
|
||||
constexpr bool use_mmap = false;
|
||||
#endif
|
||||
|
||||
llama_model_kv_override * kv_overrides = nullptr;
|
||||
if (params->kv_overrides) {
|
||||
auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
|
||||
kv_overrides = v->data();
|
||||
}
|
||||
|
||||
std::vector<std::string> splits = {};
|
||||
llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides);
|
||||
ml.init_mappings(false); // no prefetching
|
||||
|
||||
llama_model model(llama_model_default_params());
|
||||
|
||||
model.load_arch (ml);
|
||||
model.load_hparams(ml);
|
||||
model.load_stats (ml);
|
||||
|
||||
struct quantize_state_impl qs(model, params);
|
||||
|
||||
if (params->only_copy) {
|
||||
ftype = ml.ftype;
|
||||
}
|
||||
const std::unordered_map<std::string, std::vector<float>> * imatrix_data = nullptr;
|
||||
if (params->imatrix) {
|
||||
imatrix_data = static_cast<const std::unordered_map<std::string, std::vector<float>>*>(params->imatrix);
|
||||
if (imatrix_data) {
|
||||
LLAMA_LOG_INFO("================================ Have weights data with %d entries\n",int(imatrix_data->size()));
|
||||
qs.has_imatrix = true;
|
||||
// check imatrix for nans or infs
|
||||
for (const auto & kv : *imatrix_data) {
|
||||
for (float f : kv.second) {
|
||||
if (!std::isfinite(f)) {
|
||||
throw std::runtime_error(format("imatrix contains non-finite value %f\n", f));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const size_t align = GGUF_DEFAULT_ALIGNMENT;
|
||||
gguf_context_ptr ctx_out { gguf_init_empty() };
|
||||
|
||||
// copy the KV pairs from the input file
|
||||
gguf_set_kv (ctx_out.get(), ml.meta.get());
|
||||
gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
|
||||
gguf_set_val_u32(ctx_out.get(), "general.file_type", ftype); // TODO: use LLM_KV
|
||||
|
||||
// Remove split metadata
|
||||
gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str());
|
||||
gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str());
|
||||
gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str());
|
||||
|
||||
if (params->kv_overrides) {
|
||||
const std::vector<llama_model_kv_override> & overrides = *(const std::vector<llama_model_kv_override> *)params->kv_overrides;
|
||||
for (const auto & o : overrides) {
|
||||
if (o.key[0] == 0) break;
|
||||
if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) {
|
||||
gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64);
|
||||
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
|
||||
gguf_set_val_i32(ctx_out.get(), o.key, o.val_i64);
|
||||
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
|
||||
gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool);
|
||||
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
|
||||
gguf_set_val_str(ctx_out.get(), o.key, o.val_str);
|
||||
} else {
|
||||
LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// make a list of weights
|
||||
std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
|
||||
tensors.reserve(ml.weights_map.size());
|
||||
for (const auto & it : ml.weights_map) {
|
||||
tensors.push_back(&it.second);
|
||||
}
|
||||
|
||||
// keep_split requires that the weights are sorted by split index
|
||||
if (params->keep_split) {
|
||||
std::sort(tensors.begin(), tensors.end(), [](const llama_model_loader::llama_tensor_weight * a, const llama_model_loader::llama_tensor_weight * b) {
|
||||
if (a->idx == b->idx) {
|
||||
return a->offs < b->offs;
|
||||
}
|
||||
return a->idx < b->idx;
|
||||
});
|
||||
}
|
||||
|
||||
for (const auto * it : tensors) {
|
||||
const struct ggml_tensor * tensor = it->tensor;
|
||||
|
||||
const std::string name = ggml_get_name(tensor);
|
||||
|
||||
// TODO: avoid hardcoded tensor names - use the TN_* constants
|
||||
if (name.find("attn_v.weight") != std::string::npos ||
|
||||
name.find("attn_qkv.weight") != std::string::npos ||
|
||||
name.find("attn_kv_b.weight")!= std::string::npos) {
|
||||
++qs.n_attention_wv;
|
||||
} else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
|
||||
qs.has_output = true;
|
||||
}
|
||||
}
|
||||
|
||||
qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
|
||||
|
||||
// sanity checks for models that have attention layers
|
||||
if (qs.n_attention_wv != 0)
|
||||
{
|
||||
const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
|
||||
// attention layers have a non-zero number of kv heads
|
||||
int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
|
||||
if (llama_model_has_encoder(&model)) {
|
||||
n_attn_layer *= 3;
|
||||
}
|
||||
GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
|
||||
}
|
||||
|
||||
size_t total_size_org = 0;
|
||||
size_t total_size_new = 0;
|
||||
|
||||
std::vector<std::thread> workers;
|
||||
workers.reserve(nthread);
|
||||
|
||||
int idx = 0;
|
||||
|
||||
std::vector<no_init<uint8_t>> read_data;
|
||||
std::vector<no_init<uint8_t>> work;
|
||||
std::vector<no_init<float>> f32_conv_buf;
|
||||
|
||||
uint16_t n_split = 1;
|
||||
|
||||
// Assume split index is continuous
|
||||
if (params->keep_split) {
|
||||
for (const auto * it : tensors) {
|
||||
n_split = std::max(uint16_t(it->idx + 1), n_split);
|
||||
}
|
||||
}
|
||||
std::vector<gguf_context_ptr> ctx_outs(n_split);
|
||||
ctx_outs[0] = std::move(ctx_out);
|
||||
|
||||
// populate the original tensors so we get an initial meta data
|
||||
for (const auto * it : tensors) {
|
||||
uint16_t i_split = params->keep_split ? it->idx : 0;
|
||||
struct ggml_tensor * tensor = it->tensor;
|
||||
if (!ctx_outs[i_split]) {
|
||||
ctx_outs[i_split].reset(gguf_init_empty());
|
||||
}
|
||||
gguf_add_tensor(ctx_outs[i_split].get(), tensor);
|
||||
}
|
||||
|
||||
// Set split info if needed
|
||||
if (n_split > 1) {
|
||||
for (size_t i = 0; i < ctx_outs.size(); ++i) {
|
||||
gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i);
|
||||
gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split);
|
||||
gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), ml.n_tensors);
|
||||
}
|
||||
}
|
||||
|
||||
int cur_split = -1;
|
||||
std::ofstream fout;
|
||||
auto close_ofstream = [&]() {
|
||||
// Write metadata and close file handler
|
||||
if (fout.is_open()) {
|
||||
fout.seekp(0);
|
||||
std::vector<uint8_t> data(gguf_get_meta_size(ctx_outs[cur_split].get()));
|
||||
gguf_get_meta_data(ctx_outs[cur_split].get(), data.data());
|
||||
fout.write((const char *) data.data(), data.size());
|
||||
fout.close();
|
||||
}
|
||||
};
|
||||
auto new_ofstream = [&](int index) {
|
||||
cur_split = index;
|
||||
GGML_ASSERT(ctx_outs[cur_split] && "Find uninitialized gguf_context");
|
||||
std::string fname = fname_out;
|
||||
if (params->keep_split) {
|
||||
std::vector<char> split_path(llama_path_max(), 0);
|
||||
llama_split_path(split_path.data(), split_path.size(), fname_out.c_str(), cur_split, n_split);
|
||||
fname = std::string(split_path.data());
|
||||
}
|
||||
|
||||
fout = std::ofstream(fname, std::ios::binary);
|
||||
fout.exceptions(std::ofstream::failbit); // fail fast on write errors
|
||||
const size_t meta_size = gguf_get_meta_size(ctx_outs[cur_split].get());
|
||||
// placeholder for the meta data
|
||||
::zeros(fout, meta_size);
|
||||
};
|
||||
|
||||
const auto tn = LLM_TN(model.arch);
|
||||
new_ofstream(0);
|
||||
for (const auto * it : tensors) {
|
||||
const auto & weight = *it;
|
||||
struct ggml_tensor * tensor = weight.tensor;
|
||||
if (weight.idx != cur_split && params->keep_split) {
|
||||
close_ofstream();
|
||||
new_ofstream(weight.idx);
|
||||
}
|
||||
|
||||
const std::string name = ggml_get_name(tensor);
|
||||
|
||||
if (!ml.use_mmap) {
|
||||
if (read_data.size() < ggml_nbytes(tensor)) {
|
||||
read_data.resize(ggml_nbytes(tensor));
|
||||
}
|
||||
tensor->data = read_data.data();
|
||||
}
|
||||
ml.load_data_for(tensor);
|
||||
|
||||
LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ",
|
||||
++idx, ml.n_tensors,
|
||||
ggml_get_name(tensor),
|
||||
llama_format_tensor_shape(tensor).c_str(),
|
||||
ggml_type_name(tensor->type));
|
||||
|
||||
// This used to be a regex, but <regex> has an extreme cost to compile times.
|
||||
bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
|
||||
|
||||
// quantize only 2D and 3D tensors (experts)
|
||||
quantize &= (ggml_n_dims(tensor) >= 2);
|
||||
|
||||
// do not quantize norm tensors
|
||||
quantize &= name.find("_norm.weight") == std::string::npos;
|
||||
|
||||
quantize &= params->quantize_output_tensor || name != "output.weight";
|
||||
quantize &= !params->only_copy;
|
||||
|
||||
// do not quantize expert gating tensors
|
||||
// NOTE: can't use LLM_TN here because the layer number is not known
|
||||
quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;
|
||||
|
||||
// do not quantize positional embeddings and token types (BERT)
|
||||
quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight");
|
||||
quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight");
|
||||
|
||||
// do not quantize Mamba's small yet 2D weights
|
||||
// NOTE: can't use LLM_TN here because the layer number is not known
|
||||
quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
|
||||
|
||||
// do not quantize RWKV's time_mix_first tensors
|
||||
quantize &= name.find("time_mix_first.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_w1.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_w2.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
|
||||
quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
|
||||
|
||||
// do not quantize relative position bias (T5)
|
||||
quantize &= name.find("attn_rel_b.weight") == std::string::npos;
|
||||
|
||||
enum ggml_type new_type;
|
||||
void * new_data;
|
||||
size_t new_size;
|
||||
|
||||
if (quantize) {
|
||||
new_type = default_type;
|
||||
|
||||
// get more optimal quantization type based on the tensor shape, layer, etc.
|
||||
if (!params->pure && ggml_is_quantized(default_type)) {
|
||||
new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
|
||||
}
|
||||
if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
|
||||
new_type = params->token_embedding_type;
|
||||
}
|
||||
if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) {
|
||||
new_type = params->output_tensor_type;
|
||||
}
|
||||
|
||||
// If we've decided to quantize to the same type the tensor is already
|
||||
// in then there's nothing to do.
|
||||
quantize = tensor->type != new_type;
|
||||
}
|
||||
|
||||
if (!quantize) {
|
||||
new_type = tensor->type;
|
||||
new_data = tensor->data;
|
||||
new_size = ggml_nbytes(tensor);
|
||||
LLAMA_LOG_INFO("size = %8.3f MB\n", ggml_nbytes(tensor)/1024.0/1024.0);
|
||||
} else {
|
||||
const int64_t nelements = ggml_nelements(tensor);
|
||||
|
||||
const float * imatrix = nullptr;
|
||||
if (imatrix_data) {
|
||||
auto it = imatrix_data->find(tensor->name);
|
||||
if (it == imatrix_data->end()) {
|
||||
LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
|
||||
} else {
|
||||
if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) {
|
||||
imatrix = it->second.data();
|
||||
} else {
|
||||
LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__,
|
||||
int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name);
|
||||
|
||||
// this can happen when quantizing an old mixtral model with split tensors with a new incompatible imatrix
|
||||
// this is a significant error and it may be good idea to abort the process if this happens,
|
||||
// since many people will miss the error and not realize that most of the model is being quantized without an imatrix
|
||||
// tok_embd should be ignored in this case, since it always causes this warning
|
||||
if (name != tn(LLM_TENSOR_TOKEN_EMBD, "weight")) {
|
||||
throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s",
|
||||
int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if ((new_type == GGML_TYPE_IQ2_XXS ||
|
||||
new_type == GGML_TYPE_IQ2_XS ||
|
||||
new_type == GGML_TYPE_IQ2_S ||
|
||||
new_type == GGML_TYPE_IQ1_S ||
|
||||
(new_type == GGML_TYPE_IQ1_M && strcmp(tensor->name, "token_embd.weight") && strcmp(tensor->name, "output.weight")) ||
|
||||
(new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0)) && !imatrix) {
|
||||
LLAMA_LOG_ERROR("\n\n============================================================\n");
|
||||
LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name);
|
||||
LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n");
|
||||
LLAMA_LOG_ERROR("============================================================\n\n");
|
||||
throw std::runtime_error(format("Missing importance matrix for tensor %s in a very low-bit quantization", tensor->name));
|
||||
}
|
||||
|
||||
float * f32_data;
|
||||
|
||||
if (tensor->type == GGML_TYPE_F32) {
|
||||
f32_data = (float *) tensor->data;
|
||||
} else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) {
|
||||
throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type)));
|
||||
} else {
|
||||
llama_tensor_dequantize_impl(tensor, f32_conv_buf, workers, nelements, nthread);
|
||||
f32_data = (float *) f32_conv_buf.data();
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type));
|
||||
fflush(stdout);
|
||||
|
||||
if (work.size() < (size_t)nelements * 4) {
|
||||
work.resize(nelements * 4); // upper bound on size
|
||||
}
|
||||
new_data = work.data();
|
||||
|
||||
const int64_t n_per_row = tensor->ne[0];
|
||||
const int64_t nrows = tensor->ne[1];
|
||||
|
||||
static const int64_t min_chunk_size = 32 * 512;
|
||||
const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row));
|
||||
|
||||
const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1];
|
||||
const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size;
|
||||
const int64_t nthread_use = nthread > 1 ? std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1;
|
||||
|
||||
// quantize each expert separately since they have different importance matrices
|
||||
new_size = 0;
|
||||
for (int64_t i03 = 0; i03 < tensor->ne[2]; ++i03) {
|
||||
const float * f32_data_03 = f32_data + i03 * nelements_matrix;
|
||||
void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows;
|
||||
const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr;
|
||||
|
||||
new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use);
|
||||
}
|
||||
LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
|
||||
}
|
||||
total_size_org += ggml_nbytes(tensor);
|
||||
total_size_new += new_size;
|
||||
|
||||
// update the gguf meta data as we go
|
||||
gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type);
|
||||
GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size);
|
||||
gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data);
|
||||
|
||||
// write tensor data + padding
|
||||
fout.write((const char *) new_data, new_size);
|
||||
zeros(fout, GGML_PAD(new_size, align) - new_size);
|
||||
}
|
||||
close_ofstream();
|
||||
|
||||
LLAMA_LOG_INFO("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
|
||||
LLAMA_LOG_INFO("%s: quant size = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0);
|
||||
|
||||
if (qs.n_fallback > 0) {
|
||||
LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) required fallback quantization\n",
|
||||
__func__, qs.n_fallback, qs.n_k_quantized + qs.n_fallback);
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// interface implementation
|
||||
//
|
||||
|
||||
struct llama_model_quantize_params llama_model_quantize_default_params() {
|
||||
struct llama_model_quantize_params result = {
|
||||
/*.nthread =*/ 0,
|
||||
/*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1,
|
||||
/*.output_tensor_type =*/ GGML_TYPE_COUNT,
|
||||
/*.token_embedding_type =*/ GGML_TYPE_COUNT,
|
||||
/*.allow_requantize =*/ false,
|
||||
/*.quantize_output_tensor =*/ true,
|
||||
/*.only_copy =*/ false,
|
||||
/*.pure =*/ false,
|
||||
/*.keep_split =*/ false,
|
||||
/*.imatrix =*/ nullptr,
|
||||
/*.kv_overrides =*/ nullptr,
|
||||
};
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
uint32_t llama_model_quantize(
|
||||
const char * fname_inp,
|
||||
const char * fname_out,
|
||||
const llama_model_quantize_params * params) {
|
||||
try {
|
||||
llama_model_quantize_impl(fname_inp, fname_out, params);
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: failed to quantize: %s\n", __func__, err.what());
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
1
examples/talk-llama/llama-quant.h
Normal file
1
examples/talk-llama/llama-quant.h
Normal file
@ -0,0 +1 @@
|
||||
#pragma once
|
@ -1,5 +1,6 @@
|
||||
#include "llama-sampling.h"
|
||||
|
||||
#include "llama-impl.h"
|
||||
#include "llama-vocab.h"
|
||||
#include "llama-grammar.h"
|
||||
|
||||
@ -14,6 +15,118 @@
|
||||
#include <numeric>
|
||||
#include <random>
|
||||
#include <unordered_map>
|
||||
#include <stdexcept>
|
||||
|
||||
// the ring buffer works similarly to std::deque, but with a fixed capacity
|
||||
template<typename T>
|
||||
struct ring_buffer {
|
||||
ring_buffer(size_t cap) : capacity(cap), data(cap) {}
|
||||
|
||||
T & front() {
|
||||
if (sz == 0) {
|
||||
throw std::runtime_error("ring buffer is empty");
|
||||
}
|
||||
return data[first];
|
||||
}
|
||||
|
||||
const T & front() const {
|
||||
if (sz == 0) {
|
||||
throw std::runtime_error("ring buffer is empty");
|
||||
}
|
||||
return data[first];
|
||||
}
|
||||
|
||||
T & back() {
|
||||
if (sz == 0) {
|
||||
throw std::runtime_error("ring buffer is empty");
|
||||
}
|
||||
return data[pos];
|
||||
}
|
||||
|
||||
const T & back() const {
|
||||
if (sz == 0) {
|
||||
throw std::runtime_error("ring buffer is empty");
|
||||
}
|
||||
return data[pos];
|
||||
}
|
||||
|
||||
void push_back(const T & value) {
|
||||
if (capacity == 0) {
|
||||
throw std::runtime_error("ring buffer: capacity is zero");
|
||||
}
|
||||
|
||||
if (sz == capacity) {
|
||||
// advance the start when buffer is full
|
||||
first = (first + 1) % capacity;
|
||||
} else {
|
||||
sz++;
|
||||
}
|
||||
data[pos] = value;
|
||||
pos = (pos + 1) % capacity;
|
||||
}
|
||||
|
||||
T pop_front() {
|
||||
if (sz == 0) {
|
||||
throw std::runtime_error("ring buffer is empty");
|
||||
}
|
||||
T value = data[first];
|
||||
first = (first + 1) % capacity;
|
||||
sz--;
|
||||
return value;
|
||||
}
|
||||
|
||||
//T & operator[](size_t i) {
|
||||
// if (i >= sz) {
|
||||
// throw std::runtime_error("ring buffer: index out of bounds");
|
||||
// }
|
||||
// return data[(first + i) % capacity];
|
||||
//}
|
||||
|
||||
//const T & at(size_t i) const {
|
||||
// if (i >= sz) {
|
||||
// throw std::runtime_error("ring buffer: index out of bounds");
|
||||
// }
|
||||
// return data[(first + i) % capacity];
|
||||
//}
|
||||
|
||||
const T & rat(size_t i) const {
|
||||
if (i >= sz) {
|
||||
throw std::runtime_error("ring buffer: index out of bounds");
|
||||
}
|
||||
return data[(first + sz - i - 1) % capacity];
|
||||
}
|
||||
|
||||
std::vector<T> to_vector() const {
|
||||
std::vector<T> result;
|
||||
result.reserve(sz);
|
||||
for (size_t i = 0; i < sz; i++) {
|
||||
result.push_back(data[(first + i) % capacity]);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void clear() {
|
||||
// here only reset the status of the buffer
|
||||
sz = 0;
|
||||
first = 0;
|
||||
pos = 0;
|
||||
}
|
||||
|
||||
bool empty() const {
|
||||
return sz == 0;
|
||||
}
|
||||
|
||||
size_t size() const {
|
||||
return sz;
|
||||
}
|
||||
|
||||
size_t capacity = 0;
|
||||
size_t sz = 0;
|
||||
size_t first = 0;
|
||||
size_t pos = 0;
|
||||
|
||||
std::vector<T> data;
|
||||
};
|
||||
|
||||
static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
|
||||
// iterator for the probabilities
|
||||
@ -144,7 +257,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
|
||||
for (int i = 0; i < (int)cur_p->size; ++i) {
|
||||
const float val = cur_p->data[i].logit;
|
||||
int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
|
||||
ib = std::max(0, std::min(nbuckets-1, ib));
|
||||
ib = std::max(0, std::min(nbuckets - 1, ib));
|
||||
bucket_idx[i] = ib;
|
||||
++histo[ib];
|
||||
}
|
||||
@ -167,13 +280,13 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
|
||||
for (int i = 0; i < (int)cur_p->size; ++i) {
|
||||
int j = bucket_idx[i];
|
||||
if (j >= ib) {
|
||||
*bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i];
|
||||
*bucket_ptrs[nbuckets - 1 - j]++ = cur_p->data[i];
|
||||
}
|
||||
}
|
||||
|
||||
ptr = tmp_tokens.data();
|
||||
int ndone = 0;
|
||||
for (int j = nbuckets-1; j > ib; --j) {
|
||||
for (int j = nbuckets - 1; j > ib; --j) {
|
||||
std::sort(ptr, ptr + histo[j], comp);
|
||||
ptr += histo[j];
|
||||
ndone += histo[j];
|
||||
@ -258,7 +371,10 @@ void llama_sampler_free(struct llama_sampler * smpl) {
|
||||
llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
|
||||
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||
|
||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
const int n_vocab = llama_vocab_n_tokens(vocab);
|
||||
|
||||
// TODO: do not allocate each time
|
||||
std::vector<llama_token_data> cur;
|
||||
@ -1317,13 +1433,30 @@ static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token
|
||||
}
|
||||
}
|
||||
|
||||
// Fwd declare to break reset --> init_impl --> llama_sampler_grammar_i --> reset cycle.
|
||||
static struct llama_sampler * llama_sampler_init_grammar_impl(
|
||||
const struct llama_vocab * vocab,
|
||||
const char * grammar_str,
|
||||
const char * grammar_root,
|
||||
bool lazy,
|
||||
const char ** trigger_words,
|
||||
size_t num_trigger_words,
|
||||
const llama_token * trigger_tokens,
|
||||
size_t num_trigger_tokens);
|
||||
|
||||
static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
|
||||
auto * ctx = (llama_sampler_grammar *) smpl->ctx;
|
||||
if (!ctx->grammar) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str());
|
||||
std::vector<const char *> trigger_words;
|
||||
for (auto & word : ctx->grammar->trigger_words) {
|
||||
trigger_words.push_back(word.c_str());
|
||||
}
|
||||
auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
|
||||
ctx->grammar->lazy, trigger_words.data(), trigger_words.size(),
|
||||
ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
|
||||
|
||||
llama_grammar_free_impl(ctx->grammar);
|
||||
ctx->grammar = grammar_new;
|
||||
@ -1332,7 +1465,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
|
||||
static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
|
||||
|
||||
auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr);
|
||||
auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0);
|
||||
|
||||
// copy the state
|
||||
{
|
||||
@ -1368,19 +1501,27 @@ static struct llama_sampler_i llama_sampler_grammar_i = {
|
||||
/* .free = */ llama_sampler_grammar_free,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
|
||||
static struct llama_sampler * llama_sampler_init_grammar_impl(
|
||||
const struct llama_vocab * vocab,
|
||||
const char * grammar_str,
|
||||
const char * grammar_root,
|
||||
bool lazy,
|
||||
const char ** trigger_words,
|
||||
size_t num_trigger_words,
|
||||
const llama_token * trigger_tokens,
|
||||
size_t num_trigger_tokens) {
|
||||
auto * ctx = new llama_sampler_grammar;
|
||||
|
||||
if (grammar_str != nullptr && grammar_str[0] != '\0') {
|
||||
*ctx = {
|
||||
/* .vocab = */ &vocab,
|
||||
/* .vocab = */ vocab,
|
||||
/* .grammar_str = */ grammar_str,
|
||||
/* .grammar_root = */ grammar_root,
|
||||
/* .grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
|
||||
/* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens),
|
||||
};
|
||||
} else {
|
||||
*ctx = {
|
||||
/* .vocab = */ &vocab,
|
||||
/* .vocab = */ vocab,
|
||||
/* .grammar_str = */ {},
|
||||
/* .grammar_root = */ {},
|
||||
/* .grammar = */ nullptr,
|
||||
@ -1393,6 +1534,24 @@ struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab
|
||||
};
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_init_grammar(
|
||||
const struct llama_vocab * vocab,
|
||||
const char * grammar_str,
|
||||
const char * grammar_root) {
|
||||
return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0);
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_init_grammar_lazy(
|
||||
const struct llama_vocab * vocab,
|
||||
const char * grammar_str,
|
||||
const char * grammar_root,
|
||||
const char ** trigger_words,
|
||||
size_t num_trigger_words,
|
||||
const llama_token * trigger_tokens,
|
||||
size_t num_trigger_tokens) {
|
||||
return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens);
|
||||
}
|
||||
|
||||
// penalties
|
||||
|
||||
struct llama_sampler_penalties {
|
||||
@ -1550,8 +1709,8 @@ struct llama_sampler_dry {
|
||||
|
||||
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
|
||||
static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
|
||||
for (llama_token token_id = 0; token_id < (llama_token)vocab.n_vocab; token_id++) {
|
||||
std::string word = llama_detokenize(vocab, {token_id}, true);
|
||||
for (llama_token token_id = 0; token_id < (llama_token) vocab.n_tokens(); token_id++) {
|
||||
std::string word = vocab.detokenize({token_id}, true);
|
||||
if (word.find(str) != std::string::npos) {
|
||||
token_sequences.emplace(token_id, std::vector<llama_token>());
|
||||
} else {
|
||||
@ -1568,7 +1727,7 @@ static void get_overlapping_token_sequences(const llama_vocab & vocab, const std
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
std::vector<llama_token> tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false);
|
||||
std::vector<llama_token> tokenization = vocab.tokenize(str.substr(i), false, false);
|
||||
if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
|
||||
tokenization.resize(max_tail_len);
|
||||
}
|
||||
@ -1719,7 +1878,7 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_dat
|
||||
ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
|
||||
if (n > 0) {
|
||||
lt = k;
|
||||
rt = k+n-1;
|
||||
rt = k + n - 1;
|
||||
}
|
||||
} else {
|
||||
// If k is inside the current Z-box, consider two cases.
|
||||
@ -1824,7 +1983,7 @@ static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler
|
||||
llama_vocab dummy_vocab;
|
||||
|
||||
// dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying
|
||||
auto * result = llama_sampler_init_dry_impl(dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
|
||||
auto * result = llama_sampler_init_dry(&dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
|
||||
|
||||
// Copy the state, including the processed breakers
|
||||
{
|
||||
@ -1851,7 +2010,7 @@ static struct llama_sampler_i llama_sampler_dry_i = {
|
||||
/* .free = */ llama_sampler_dry_free,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
|
||||
struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
|
||||
int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
|
||||
std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
|
||||
const int MAX_CHAR_LEN = 40;
|
||||
@ -1878,7 +2037,7 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
|
||||
sequence_break.resize(MAX_CHAR_LEN);
|
||||
}
|
||||
|
||||
get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
|
||||
get_overlapping_token_sequences(*vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1901,7 +2060,7 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
|
||||
// wrapper for test-sampling.cpp
|
||||
struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
|
||||
llama_vocab dummy_vocab;
|
||||
auto * result = llama_sampler_init_dry_impl(dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
|
||||
auto * result = llama_sampler_init_dry(&dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
|
||||
auto * ctx = (llama_sampler_dry *) result->ctx;
|
||||
|
||||
// Process the token-based sequence breakers
|
||||
@ -2040,7 +2199,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
|
||||
float p_eog_sum = 0.0f;
|
||||
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
|
||||
if (ctx->vocab->is_eog(cur_p->data[i].id)) {
|
||||
p_eog_sum += cur_p->data[i].p;
|
||||
} else {
|
||||
p_txt_sum += cur_p->data[i].p;
|
||||
@ -2062,7 +2221,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
|
||||
float p_sum = 0.0f;
|
||||
|
||||
for (size_t i = 0; i < size_org; ++i) {
|
||||
if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
|
||||
if (ctx->vocab->is_eog(cur_p->data[i].id)) {
|
||||
p_sum += cur_p->data[i].p;
|
||||
|
||||
cur_p->data[cur_p->size++] = cur_p->data[i];
|
||||
@ -2090,17 +2249,17 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
|
||||
continue;
|
||||
}
|
||||
|
||||
int len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
|
||||
int len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
|
||||
if (len0 < 0) {
|
||||
ctx->buf0.resize(len0);
|
||||
len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
|
||||
len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
|
||||
assert(len0 > 0);
|
||||
}
|
||||
|
||||
int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
|
||||
int len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
|
||||
if (len1 < 0) {
|
||||
ctx->buf1.resize(len1);
|
||||
len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
|
||||
len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
|
||||
assert(len1 > 0);
|
||||
}
|
||||
|
||||
@ -2135,7 +2294,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
|
||||
LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
|
||||
|
||||
for (size_t i = 0; i < size_org; ++i) {
|
||||
const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
|
||||
const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id);
|
||||
|
||||
if (cur_p->data[i].p < thold && !is_eog) {
|
||||
continue;
|
||||
@ -2156,7 +2315,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
|
||||
// if no non-EOG tokens are left -> reduce cur_p to single EOT token
|
||||
if (n_non_eog == 0) {
|
||||
cur_p->size = 1;
|
||||
cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
|
||||
cur_p->data[0].id = ctx->vocab->token_eot();
|
||||
cur_p->data[0].logit = 1.0f;
|
||||
|
||||
return;
|
||||
@ -2178,7 +2337,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
|
||||
LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
|
||||
|
||||
for (size_t i = 0; i < size_org; ++i) {
|
||||
const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
|
||||
const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id);
|
||||
|
||||
if (cur_p->data[i].p < thold && !is_eog) {
|
||||
continue;
|
||||
@ -2201,7 +2360,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
|
||||
|
||||
static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
|
||||
return llama_sampler_init_infill_impl(*ctx->vocab);
|
||||
return llama_sampler_init_infill(ctx->vocab);
|
||||
}
|
||||
|
||||
static void llama_sampler_infill_free(struct llama_sampler * smpl) {
|
||||
@ -2217,14 +2376,13 @@ static struct llama_sampler_i llama_sampler_infill_i = {
|
||||
/* .free = */ llama_sampler_infill_free,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_infill_impl(
|
||||
const struct llama_vocab & vocab) {
|
||||
struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
|
||||
return new llama_sampler {
|
||||
/* .iface = */ &llama_sampler_infill_i,
|
||||
/* .ctx = */ new llama_sampler_infill {
|
||||
/* .vocab = */ &vocab,
|
||||
/* .buf0 = */ std::vector<char>(512),
|
||||
/* .buf1 = */ std::vector<char>(512),
|
||||
/* .vocab = */ vocab,
|
||||
/* .buf0 = */ std::vector<char>(512),
|
||||
/* .buf1 = */ std::vector<char>(512),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
@ -2,7 +2,9 @@
|
||||
|
||||
// TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ?
|
||||
|
||||
#include "llama-grammar.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
struct llama_vocab;
|
||||
struct llama_grammar;
|
||||
@ -21,24 +23,6 @@ struct llama_sampler_chain {
|
||||
mutable int32_t n_sample;
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_grammar_impl(
|
||||
const struct llama_vocab & vocab,
|
||||
const char * grammar_str,
|
||||
const char * grammar_root);
|
||||
|
||||
struct llama_sampler * llama_sampler_init_infill_impl(
|
||||
const struct llama_vocab & vocab);
|
||||
|
||||
struct llama_sampler * llama_sampler_init_dry_impl(
|
||||
const struct llama_vocab & vocab,
|
||||
int32_t context_size,
|
||||
float dry_multiplier,
|
||||
float dry_base,
|
||||
int32_t dry_allowed_length,
|
||||
int32_t dry_penalty_last_n,
|
||||
const char ** seq_breakers,
|
||||
size_t num_breakers);
|
||||
|
||||
struct llama_sampler * llama_sampler_init_dry_testing(
|
||||
int32_t context_size,
|
||||
float dry_multiplier,
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,170 +1,125 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama-impl.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <memory>
|
||||
|
||||
struct llm_tokenizer;
|
||||
struct LLM_KV;
|
||||
struct llama_model_loader;
|
||||
|
||||
struct llama_vocab {
|
||||
using id = llama_token;
|
||||
using token = std::string;
|
||||
using tattr = llama_token_attr;
|
||||
|
||||
struct token_data {
|
||||
token text;
|
||||
float score;
|
||||
tattr attr;
|
||||
std::string text;
|
||||
float score;
|
||||
llama_token_attr attr;
|
||||
};
|
||||
|
||||
uint32_t n_vocab = 0; // TODO: not great because has to keep in sync with hparams.n_vocab
|
||||
|
||||
enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
|
||||
enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
||||
|
||||
int max_token_len = 0; // used for optimizing longest token search
|
||||
|
||||
std::unordered_map<token, id> token_to_id;
|
||||
std::vector<token_data> id_to_token;
|
||||
|
||||
std::vector<id> cache_special_tokens;
|
||||
std::vector<token> cache_token_to_piece; // llama_token_to_piece(special = true);
|
||||
|
||||
std::map<std::pair<std::string, std::string>, int> bpe_ranks;
|
||||
|
||||
// default LLaMA special tokens
|
||||
// TODO: should we set all of these to LLAMA_TOKEN_NULL?
|
||||
id special_bos_id = 1;
|
||||
id special_eos_id = 2;
|
||||
id special_eot_id = LLAMA_TOKEN_NULL;
|
||||
id special_eom_id = LLAMA_TOKEN_NULL;
|
||||
id special_unk_id = 0;
|
||||
id special_sep_id = LLAMA_TOKEN_NULL;
|
||||
id special_pad_id = LLAMA_TOKEN_NULL;
|
||||
id special_cls_id = LLAMA_TOKEN_NULL;
|
||||
id special_mask_id = LLAMA_TOKEN_NULL;
|
||||
|
||||
id linefeed_id = 13;
|
||||
|
||||
// fim tokens
|
||||
id special_fim_pre_id = LLAMA_TOKEN_NULL;
|
||||
id special_fim_suf_id = LLAMA_TOKEN_NULL;
|
||||
id special_fim_mid_id = LLAMA_TOKEN_NULL;
|
||||
id special_fim_pad_id = LLAMA_TOKEN_NULL;
|
||||
id special_fim_rep_id = LLAMA_TOKEN_NULL; // repo
|
||||
id special_fim_sep_id = LLAMA_TOKEN_NULL; // file separator
|
||||
|
||||
// set of all tokens that cause "end of generation"
|
||||
std::set<id> special_eog_ids;
|
||||
|
||||
// tokenizer flags
|
||||
bool tokenizer_add_space_prefix = false;
|
||||
bool tokenizer_add_bos = false;
|
||||
bool tokenizer_add_eos = false;
|
||||
bool tokenizer_ignore_merges = false;
|
||||
bool tokenizer_clean_spaces = false; // clean_up_tokenization_spaces
|
||||
bool tokenizer_remove_extra_whitespaces = false;
|
||||
bool tokenizer_escape_whitespaces = true;
|
||||
bool tokenizer_treat_whitespace_as_suffix = false;
|
||||
|
||||
std::vector<char> precompiled_charsmap;
|
||||
|
||||
llm_tokenizer * tokenizer = nullptr;
|
||||
|
||||
llama_vocab() = default;
|
||||
llama_vocab();
|
||||
~llama_vocab();
|
||||
|
||||
void load(llama_model_loader & ml, const LLM_KV & kv);
|
||||
|
||||
enum llama_vocab_type get_type() const;
|
||||
enum llama_vocab_pre_type get_pre_type() const;
|
||||
|
||||
uint32_t n_tokens() const;
|
||||
uint32_t n_token_types() const;
|
||||
|
||||
std::string type_name() const;
|
||||
|
||||
bool is_normal (llama_token id) const;
|
||||
bool is_unknown (llama_token id) const;
|
||||
bool is_control (llama_token id) const;
|
||||
bool is_byte (llama_token id) const;
|
||||
bool is_user_defined(llama_token id) const;
|
||||
bool is_unused (llama_token id) const;
|
||||
bool is_eog (llama_token id) const;
|
||||
|
||||
uint8_t token_to_byte(llama_token id) const;
|
||||
llama_token byte_to_token(uint8_t ch) const;
|
||||
|
||||
llama_token text_to_token(const std::string & text) const;
|
||||
|
||||
const token_data & get_token_data(llama_token id) const;
|
||||
|
||||
const char * token_get_text (llama_token id) const;
|
||||
float token_get_score(llama_token id) const;
|
||||
llama_token_attr token_get_attr (llama_token id) const;
|
||||
|
||||
llama_token token_bos() const;
|
||||
llama_token token_eos() const;
|
||||
llama_token token_eot() const;
|
||||
llama_token token_eom() const;
|
||||
llama_token token_unk() const;
|
||||
llama_token token_sep() const;
|
||||
llama_token token_nl () const;
|
||||
llama_token token_pad() const;
|
||||
|
||||
llama_token token_prefix() const;
|
||||
llama_token token_middle() const;
|
||||
llama_token token_suffix() const;
|
||||
|
||||
llama_token token_fim_pre() const;
|
||||
llama_token token_fim_suf() const;
|
||||
llama_token token_fim_mid() const;
|
||||
llama_token token_fim_pad() const;
|
||||
llama_token token_fim_rep() const;
|
||||
llama_token token_fim_sep() const;
|
||||
|
||||
bool get_add_space_prefix () const;
|
||||
bool get_add_bos () const;
|
||||
bool get_add_eos () const;
|
||||
bool get_ignore_merges () const;
|
||||
bool get_clean_spaces () const;
|
||||
bool get_remove_extra_whitespaces () const;
|
||||
bool get_escape_whitespaces () const;
|
||||
bool get_treat_whitespace_as_suffix() const;
|
||||
|
||||
int max_token_len() const;
|
||||
|
||||
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
|
||||
|
||||
void init_tokenizer();
|
||||
int32_t tokenize(
|
||||
const char * text,
|
||||
int32_t text_len,
|
||||
llama_token * tokens,
|
||||
int32_t n_tokens_max,
|
||||
bool add_special,
|
||||
bool parse_special) const;
|
||||
|
||||
std::vector<llama_token> tokenize(
|
||||
const std::string & raw_text,
|
||||
bool add_special,
|
||||
bool parse_special = false) const;
|
||||
|
||||
// does not write null-terminator to buf
|
||||
int32_t token_to_piece(
|
||||
llama_token token,
|
||||
char * buf,
|
||||
int32_t length,
|
||||
int32_t lstrip,
|
||||
bool special) const;
|
||||
|
||||
// use cached data
|
||||
const std::string & token_to_piece(llama_token token) const;
|
||||
|
||||
int32_t detokenize(
|
||||
const llama_token * tokens,
|
||||
int32_t n_tokens,
|
||||
char * text,
|
||||
int32_t text_len_max,
|
||||
bool remove_special,
|
||||
bool unparse_special) const;
|
||||
|
||||
std::string detokenize(
|
||||
const std::vector<llama_token> & tokens,
|
||||
bool special) const;
|
||||
|
||||
void print_info() const;
|
||||
|
||||
private:
|
||||
struct impl;
|
||||
std::unique_ptr<impl> pimpl;
|
||||
};
|
||||
|
||||
//
|
||||
// internal API
|
||||
//
|
||||
|
||||
// TODO: rename to llama_tokenize_impl
|
||||
// TODO: This should probably be in llama.h
|
||||
std::vector<llama_vocab::id> llama_tokenize_internal(
|
||||
const llama_vocab & vocab,
|
||||
std::string raw_text,
|
||||
bool add_special,
|
||||
bool parse_special = false);
|
||||
|
||||
// TODO: move the API below as member functions of llama_vocab
|
||||
llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
|
||||
|
||||
const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);
|
||||
|
||||
float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token);
|
||||
|
||||
llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token);
|
||||
|
||||
bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token);
|
||||
|
||||
bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token);
|
||||
|
||||
llama_token llama_token_bos_impl(const struct llama_vocab & vocab);
|
||||
llama_token llama_token_eos_impl(const struct llama_vocab & vocab);
|
||||
llama_token llama_token_eot_impl(const struct llama_vocab & vocab);
|
||||
llama_token llama_token_eom_impl(const struct llama_vocab & vocab);
|
||||
llama_token llama_token_cls_impl(const struct llama_vocab & vocab);
|
||||
llama_token llama_token_sep_impl(const struct llama_vocab & vocab);
|
||||
llama_token llama_token_nl_impl (const struct llama_vocab & vocab);
|
||||
llama_token llama_token_pad_impl(const struct llama_vocab & vocab);
|
||||
|
||||
llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
|
||||
llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
|
||||
llama_token llama_token_suffix_impl(const struct llama_vocab & vocab);
|
||||
|
||||
llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab);
|
||||
llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab);
|
||||
llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab);
|
||||
llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab);
|
||||
llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab);
|
||||
llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab);
|
||||
|
||||
bool llama_add_bos_token_impl(const struct llama_vocab & vocab);
|
||||
bool llama_add_eos_token_impl(const struct llama_vocab & vocab);
|
||||
|
||||
int32_t llama_tokenize_impl(
|
||||
const struct llama_vocab & vocab,
|
||||
const char * text,
|
||||
int32_t text_len,
|
||||
llama_token * tokens,
|
||||
int32_t n_tokens_max,
|
||||
bool add_special,
|
||||
bool parse_special);
|
||||
|
||||
// does not write null-terminator to buf
|
||||
int32_t llama_token_to_piece_impl(
|
||||
const struct llama_vocab & vocab,
|
||||
llama_token token,
|
||||
char * buf,
|
||||
int32_t length,
|
||||
int32_t lstrip,
|
||||
bool special);
|
||||
|
||||
// check if token0 is contained as a prefix in token1
|
||||
bool llama_token_is_prefix_impl(
|
||||
const struct llama_vocab & vocab,
|
||||
llama_token token0,
|
||||
llama_token token1);
|
||||
|
||||
int32_t llama_detokenize_impl(
|
||||
const struct llama_vocab & vocab,
|
||||
const llama_token * tokens,
|
||||
int32_t n_tokens,
|
||||
char * text,
|
||||
int32_t text_len_max,
|
||||
bool remove_special,
|
||||
bool unparse_special);
|
||||
|
||||
std::string llama_detokenize(
|
||||
const struct llama_vocab & vocab,
|
||||
const std::vector<llama_token> & tokens,
|
||||
bool special);
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -34,7 +34,6 @@
|
||||
|
||||
#define LLAMA_DEFAULT_SEED 0xFFFFFFFF
|
||||
|
||||
// TODO: use everywhere in the implementation
|
||||
#define LLAMA_TOKEN_NULL -1
|
||||
|
||||
#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
|
||||
@ -57,7 +56,7 @@ extern "C" {
|
||||
// TODO: show sample usage
|
||||
//
|
||||
|
||||
// struct llama_vocab; // TODO: add in the future
|
||||
struct llama_vocab;
|
||||
struct llama_model;
|
||||
struct llama_context;
|
||||
struct llama_sampler;
|
||||
@ -105,6 +104,7 @@ extern "C" {
|
||||
LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
|
||||
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
|
||||
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
|
||||
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
|
||||
};
|
||||
|
||||
enum llama_rope_type {
|
||||
@ -288,9 +288,6 @@ extern "C" {
|
||||
// proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
|
||||
const float * tensor_split;
|
||||
|
||||
// comma separated list of RPC servers to use for offloading
|
||||
const char * rpc_servers;
|
||||
|
||||
// Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
|
||||
// If the provided progress_callback returns true, model loading continues.
|
||||
// If it returns false, model loading is immediately aborted.
|
||||
@ -385,7 +382,7 @@ extern "C" {
|
||||
} llama_chat_message;
|
||||
|
||||
// lora adapter
|
||||
struct llama_lora_adapter;
|
||||
struct llama_adapter_lora;
|
||||
|
||||
// Helpers for getting default parameters
|
||||
// TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172)
|
||||
@ -399,30 +396,53 @@ extern "C" {
|
||||
// Call once at the start of the program
|
||||
LLAMA_API void llama_backend_init(void);
|
||||
|
||||
// Call once at the end of the program - currently only used for MPI
|
||||
LLAMA_API void llama_backend_free(void);
|
||||
|
||||
//optional:
|
||||
LLAMA_API void llama_numa_init(enum ggml_numa_strategy numa);
|
||||
|
||||
// Optional: an auto threadpool gets created in ggml if not passed explicitly
|
||||
LLAMA_API void llama_attach_threadpool(
|
||||
struct llama_context * ctx,
|
||||
ggml_threadpool_t threadpool,
|
||||
ggml_threadpool_t threadpool_batch);
|
||||
struct llama_context * ctx,
|
||||
ggml_threadpool_t threadpool,
|
||||
ggml_threadpool_t threadpool_batch);
|
||||
|
||||
LLAMA_API void llama_detach_threadpool(struct llama_context * ctx);
|
||||
|
||||
// Call once at the end of the program - currently only used for MPI
|
||||
LLAMA_API void llama_backend_free(void);
|
||||
DEPRECATED(LLAMA_API struct llama_model * llama_load_model_from_file(
|
||||
const char * path_model,
|
||||
struct llama_model_params params),
|
||||
"use llama_model_load_from_file instead");
|
||||
|
||||
LLAMA_API struct llama_model * llama_load_model_from_file(
|
||||
// Load the model from a file
|
||||
// If the file is split into multiple parts, the file name must follow this pattern: <name>-%05d-of-%05d.gguf
|
||||
// If the split file name does not follow this pattern, use llama_model_load_from_splits
|
||||
LLAMA_API struct llama_model * llama_model_load_from_file(
|
||||
const char * path_model,
|
||||
struct llama_model_params params);
|
||||
|
||||
LLAMA_API void llama_free_model(struct llama_model * model);
|
||||
// Load the model from multiple splits (support custom naming scheme)
|
||||
// The paths must be in the correct order
|
||||
LLAMA_API struct llama_model * llama_model_load_from_splits(
|
||||
const char ** paths,
|
||||
size_t n_paths,
|
||||
struct llama_model_params params);
|
||||
|
||||
// TODO: rename to llama_init_from_model
|
||||
LLAMA_API struct llama_context * llama_new_context_with_model(
|
||||
DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model),
|
||||
"use llama_model_free instead");
|
||||
|
||||
LLAMA_API void llama_model_free(struct llama_model * model);
|
||||
|
||||
LLAMA_API struct llama_context * llama_init_from_model(
|
||||
struct llama_model * model,
|
||||
struct llama_context_params params);
|
||||
|
||||
DEPRECATED(LLAMA_API struct llama_context * llama_new_context_with_model(
|
||||
struct llama_model * model,
|
||||
struct llama_context_params params),
|
||||
"use llama_init_from_model instead");
|
||||
|
||||
// Frees all allocated memory
|
||||
LLAMA_API void llama_free(struct llama_context * ctx);
|
||||
|
||||
@ -440,20 +460,30 @@ extern "C" {
|
||||
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
|
||||
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
|
||||
|
||||
LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_n_layer (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_n_head (const struct llama_model * model);
|
||||
DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead");
|
||||
DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead");
|
||||
DEPRECATED(LLAMA_API int32_t llama_n_layer (const struct llama_model * model), "use llama_model_n_layer instead");
|
||||
DEPRECATED(LLAMA_API int32_t llama_n_head (const struct llama_model * model), "use llama_model_n_head instead");
|
||||
|
||||
LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
|
||||
DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
|
||||
|
||||
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
|
||||
LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
|
||||
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
|
||||
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
|
||||
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
|
||||
|
||||
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
|
||||
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
|
||||
|
||||
LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
|
||||
|
||||
// Get the model's RoPE frequency scaling factor
|
||||
LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
|
||||
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
|
||||
|
||||
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab);
|
||||
|
||||
LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab);
|
||||
|
||||
// Functions to access the model's GGUF metadata scalar values
|
||||
// - The functions return the length of the string on success, or -1 on failure
|
||||
@ -479,12 +509,13 @@ extern "C" {
|
||||
// Returns the total size of all the tensors in the model in bytes
|
||||
LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
|
||||
|
||||
// Get the default chat template. Returns nullptr if not available
|
||||
// If name is NULL, returns the default chat template
|
||||
LLAMA_API const char * llama_model_chat_template(const struct llama_model * model, const char * name);
|
||||
|
||||
// Returns the total number of parameters in the model
|
||||
LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
|
||||
|
||||
// Get a llama model tensor
|
||||
LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
|
||||
|
||||
// Returns true if the model contains an encoder that requires llama_encode() call
|
||||
LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);
|
||||
|
||||
@ -504,32 +535,36 @@ extern "C" {
|
||||
const char * fname_out,
|
||||
const llama_model_quantize_params * params);
|
||||
|
||||
//
|
||||
// Adapters
|
||||
//
|
||||
|
||||
// Load a LoRA adapter from file
|
||||
// The loaded adapter will be associated to the given model, and will be free when the model is deleted
|
||||
LLAMA_API struct llama_lora_adapter * llama_lora_adapter_init(
|
||||
LLAMA_API struct llama_adapter_lora * llama_adapter_lora_init(
|
||||
struct llama_model * model,
|
||||
const char * path_lora);
|
||||
|
||||
// Manually free a LoRA adapter
|
||||
// Note: loaded adapters will be free when the associated model is deleted
|
||||
LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
|
||||
|
||||
// The following functions operate on a llama_context, hence the naming: llama_verb_...
|
||||
|
||||
// Add a loaded LoRA adapter to given context
|
||||
// This will not modify model's weight
|
||||
LLAMA_API int32_t llama_lora_adapter_set(
|
||||
LLAMA_API int32_t llama_set_adapter_lora(
|
||||
struct llama_context * ctx,
|
||||
struct llama_lora_adapter * adapter,
|
||||
struct llama_adapter_lora * adapter,
|
||||
float scale);
|
||||
|
||||
// Remove a specific LoRA adapter from given context
|
||||
// Return -1 if the adapter is not present in the context
|
||||
LLAMA_API int32_t llama_lora_adapter_remove(
|
||||
LLAMA_API int32_t llama_rm_adapter_lora(
|
||||
struct llama_context * ctx,
|
||||
struct llama_lora_adapter * adapter);
|
||||
struct llama_adapter_lora * adapter);
|
||||
|
||||
// Remove all LoRA adapters from given context
|
||||
LLAMA_API void llama_lora_adapter_clear(
|
||||
struct llama_context * ctx);
|
||||
|
||||
// Manually free a LoRA adapter
|
||||
// Note: loaded adapters will be free when the associated model is deleted
|
||||
LLAMA_API void llama_lora_adapter_free(struct llama_lora_adapter * adapter);
|
||||
LLAMA_API void llama_clear_adapter_lora(struct llama_context * ctx);
|
||||
|
||||
// Apply a loaded control vector to a llama_context, or if data is NULL, clear
|
||||
// the currently loaded vector.
|
||||
@ -537,8 +572,8 @@ extern "C" {
|
||||
// to an n_embd x n_layers buffer starting from layer 1.
|
||||
// il_start and il_end are the layer range the vector should apply to (both inclusive)
|
||||
// See llama_control_vector_load in common to load a control vector.
|
||||
LLAMA_API int32_t llama_control_vector_apply(
|
||||
struct llama_context * lctx,
|
||||
LLAMA_API int32_t llama_apply_adapter_cvec(
|
||||
struct llama_context * ctx,
|
||||
const float * data,
|
||||
size_t len,
|
||||
int32_t n_embd,
|
||||
@ -549,6 +584,8 @@ extern "C" {
|
||||
// KV cache
|
||||
//
|
||||
|
||||
// TODO: remove llama_kv_cache_view_* API
|
||||
|
||||
// Information associated with an individual cell in the KV cache view.
|
||||
struct llama_kv_cache_view_cell {
|
||||
// The position for this cell. Takes KV cache shifts into account.
|
||||
@ -595,8 +632,11 @@ extern "C" {
|
||||
LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
|
||||
|
||||
// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
|
||||
// TODO: change signature to llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_context * ctx)
|
||||
LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
|
||||
|
||||
///
|
||||
|
||||
// Returns the number of tokens in the KV cache (slow, use only for debug)
|
||||
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
|
||||
LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx);
|
||||
@ -666,6 +706,9 @@ extern "C" {
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// TODO: the llama_kv_cache_defrag and llama_kv_cache_update API tightly couples llama_context with llama_kv_cache
|
||||
// how to avoid this?
|
||||
|
||||
// Defragment the KV cache
|
||||
// This will be applied:
|
||||
// - lazily on next llama_decode()
|
||||
@ -886,41 +929,60 @@ extern "C" {
|
||||
// Vocab
|
||||
//
|
||||
|
||||
LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token);
|
||||
LLAMA_API const char * llama_vocab_get_text(const struct llama_vocab * vocab, llama_token token);
|
||||
|
||||
LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
|
||||
LLAMA_API float llama_vocab_get_score(const struct llama_vocab * vocab, llama_token token);
|
||||
|
||||
LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token);
|
||||
LLAMA_API enum llama_token_attr llama_vocab_get_attr(const struct llama_vocab * vocab, llama_token token);
|
||||
|
||||
// Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
|
||||
LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token);
|
||||
LLAMA_API bool llama_vocab_is_eog(const struct llama_vocab * vocab, llama_token token);
|
||||
|
||||
// Identify if Token Id is a control token or a render-able token
|
||||
LLAMA_API bool llama_token_is_control(const struct llama_model * model, llama_token token);
|
||||
LLAMA_API bool llama_vocab_is_control(const struct llama_vocab * vocab, llama_token token);
|
||||
|
||||
// Special tokens
|
||||
LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
|
||||
LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
|
||||
LLAMA_API llama_token llama_token_eot(const struct llama_model * model); // end-of-turn
|
||||
LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
|
||||
LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
|
||||
LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
|
||||
LLAMA_API llama_token llama_token_pad(const struct llama_model * model); // padding
|
||||
LLAMA_API llama_token llama_vocab_bos(const struct llama_vocab * vocab); // beginning-of-sentence
|
||||
LLAMA_API llama_token llama_vocab_eos(const struct llama_vocab * vocab); // end-of-sentence
|
||||
LLAMA_API llama_token llama_vocab_eot(const struct llama_vocab * vocab); // end-of-turn
|
||||
LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator
|
||||
LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line
|
||||
LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding
|
||||
|
||||
LLAMA_API bool llama_add_bos_token(const struct llama_model * model);
|
||||
LLAMA_API bool llama_add_eos_token(const struct llama_model * model);
|
||||
LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
|
||||
LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
|
||||
|
||||
// infill tokens
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_prefix(const struct llama_model * model), "use llama_token_fim_pre instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_middle(const struct llama_model * model), "use llama_token_fim_mid instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_suffix(const struct llama_model * model), "use llama_token_fim_suf instead");
|
||||
LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab);
|
||||
LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab);
|
||||
LLAMA_API llama_token llama_vocab_fim_mid(const struct llama_vocab * vocab);
|
||||
LLAMA_API llama_token llama_vocab_fim_pad(const struct llama_vocab * vocab);
|
||||
LLAMA_API llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab);
|
||||
LLAMA_API llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab);
|
||||
|
||||
LLAMA_API llama_token llama_token_fim_pre(const struct llama_model * model);
|
||||
LLAMA_API llama_token llama_token_fim_suf(const struct llama_model * model);
|
||||
LLAMA_API llama_token llama_token_fim_mid(const struct llama_model * model);
|
||||
LLAMA_API llama_token llama_token_fim_pad(const struct llama_model * model);
|
||||
LLAMA_API llama_token llama_token_fim_rep(const struct llama_model * model);
|
||||
LLAMA_API llama_token llama_token_fim_sep(const struct llama_model * model);
|
||||
DEPRECATED(LLAMA_API const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_text instead");
|
||||
DEPRECATED(LLAMA_API float llama_token_get_score(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_score instead");
|
||||
DEPRECATED(LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_attr instead");
|
||||
DEPRECATED(LLAMA_API bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_eog instead");
|
||||
DEPRECATED(LLAMA_API bool llama_token_is_control(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_control instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_bos(const struct llama_vocab * vocab), "use llama_vocab_bos instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_eos(const struct llama_vocab * vocab), "use llama_vocab_eos instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_eot(const struct llama_vocab * vocab), "use llama_vocab_eot instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_cls(const struct llama_vocab * vocab), "use llama_vocab_cls instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_sep(const struct llama_vocab * vocab), "use llama_vocab_sep instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_nl (const struct llama_vocab * vocab), "use llama_vocab_nl instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_pad(const struct llama_vocab * vocab), "use llama_vocab_pad instead");
|
||||
DEPRECATED(LLAMA_API bool llama_add_bos_token(const struct llama_vocab * vocab), "use llama_vocab_get_add_bos instead");
|
||||
DEPRECATED(LLAMA_API bool llama_add_eos_token(const struct llama_vocab * vocab), "use llama_vocab_get_add_eos instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_fim_pre(const struct llama_vocab * vocab), "use llama_vocab_fim_pre instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_fim_suf(const struct llama_vocab * vocab), "use llama_vocab_fim_suf instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_fim_mid(const struct llama_vocab * vocab), "use llama_vocab_fim_mid instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_fim_pad(const struct llama_vocab * vocab), "use llama_vocab_fim_pad instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_fim_rep(const struct llama_vocab * vocab), "use llama_vocab_fim_rep instead");
|
||||
DEPRECATED(LLAMA_API llama_token llama_token_fim_sep(const struct llama_vocab * vocab), "use llama_vocab_fim_sep instead");
|
||||
|
||||
// CLS is equivalent to BOS
|
||||
DEPRECATED(LLAMA_API llama_token llama_vocab_cls(const struct llama_vocab * vocab), // classification
|
||||
"use llama_vocab_bos instead");
|
||||
|
||||
//
|
||||
// Tokenization
|
||||
@ -936,7 +998,7 @@ extern "C" {
|
||||
/// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
|
||||
/// as plaintext. Does not insert a leading space.
|
||||
LLAMA_API int32_t llama_tokenize(
|
||||
const struct llama_model * model,
|
||||
const struct llama_vocab * vocab,
|
||||
const char * text,
|
||||
int32_t text_len,
|
||||
llama_token * tokens,
|
||||
@ -950,7 +1012,7 @@ extern "C" {
|
||||
// User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix')
|
||||
// @param special If true, special tokens are rendered in the output.
|
||||
LLAMA_API int32_t llama_token_to_piece(
|
||||
const struct llama_model * model,
|
||||
const struct llama_vocab * vocab,
|
||||
llama_token token,
|
||||
char * buf,
|
||||
int32_t length,
|
||||
@ -964,7 +1026,7 @@ extern "C" {
|
||||
/// @param remove_special Allow to remove BOS and EOS tokens if model is configured to do so.
|
||||
/// @param unparse_special If true, special tokens are rendered in the output.
|
||||
LLAMA_API int32_t llama_detokenize(
|
||||
const struct llama_model * model,
|
||||
const struct llama_vocab * vocab,
|
||||
const llama_token * tokens,
|
||||
int32_t n_tokens,
|
||||
char * text,
|
||||
@ -987,7 +1049,6 @@ extern "C" {
|
||||
/// @param length The size of the allocated buffer
|
||||
/// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template.
|
||||
LLAMA_API int32_t llama_chat_apply_template(
|
||||
const struct llama_model * model,
|
||||
const char * tmpl,
|
||||
const struct llama_chat_message * chat,
|
||||
size_t n_msg,
|
||||
@ -1035,7 +1096,6 @@ extern "C" {
|
||||
// llama_sampler_free(smpl);
|
||||
//
|
||||
// TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
|
||||
// TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab
|
||||
//
|
||||
|
||||
typedef void * llama_sampler_context_t;
|
||||
@ -1135,10 +1195,22 @@ extern "C" {
|
||||
float eta);
|
||||
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
|
||||
const struct llama_model * model,
|
||||
const struct llama_vocab * vocab,
|
||||
const char * grammar_str,
|
||||
const char * grammar_root);
|
||||
|
||||
/// @details Lazy grammar sampler, introduced in https://github.com/ggerganov/llama.cpp/pull/9639
|
||||
/// @param trigger_words A list of words that will trigger the grammar sampler. This may be updated to a loose regex syntax (w/ ^) in a near future.
|
||||
/// @param trigger_tokens A list of tokens that will trigger the grammar sampler.
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy(
|
||||
const struct llama_vocab * vocab,
|
||||
const char * grammar_str,
|
||||
const char * grammar_root,
|
||||
const char ** trigger_words,
|
||||
size_t num_trigger_words,
|
||||
const llama_token * trigger_tokens,
|
||||
size_t num_trigger_tokens);
|
||||
|
||||
/// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
|
||||
int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
||||
@ -1147,8 +1219,9 @@ extern "C" {
|
||||
float penalty_present); // 0.0 = disabled
|
||||
|
||||
/// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_dry(
|
||||
const struct llama_model * model,
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_dry(
|
||||
const struct llama_vocab * vocab,
|
||||
int32_t n_ctx_train,
|
||||
float dry_multiplier,
|
||||
float dry_base,
|
||||
int32_t dry_allowed_length,
|
||||
@ -1182,7 +1255,7 @@ extern "C" {
|
||||
// 3. discard non-EOG tokens with low prob
|
||||
// 4. if no tokens are left -> pick EOT
|
||||
//
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model);
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab);
|
||||
|
||||
// Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
|
||||
LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
|
||||
|
@ -15,17 +15,24 @@
|
||||
#include <vector>
|
||||
#include <regex>
|
||||
#include <sstream>
|
||||
#include <chrono>
|
||||
|
||||
#if defined(_WIN32)
|
||||
#define NOMINMAX
|
||||
#include <windows.h>
|
||||
#endif
|
||||
|
||||
static std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
|
||||
auto * model = llama_get_model(ctx);
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
// upper limit for the number of tokens
|
||||
int n_tokens = text.length() + add_bos;
|
||||
std::vector<llama_token> result(n_tokens);
|
||||
n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, false);
|
||||
n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_bos, false);
|
||||
if (n_tokens < 0) {
|
||||
result.resize(-n_tokens);
|
||||
int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, false);
|
||||
int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_bos, false);
|
||||
GGML_ASSERT(check == -n_tokens);
|
||||
} else {
|
||||
result.resize(n_tokens);
|
||||
@ -34,11 +41,14 @@ static std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const
|
||||
}
|
||||
|
||||
static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
std::vector<char> result(8, 0);
|
||||
const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), 0, false);
|
||||
const int n_tokens = llama_token_to_piece(vocab, token, result.data(), result.size(), 0, false);
|
||||
if (n_tokens < 0) {
|
||||
result.resize(-n_tokens);
|
||||
int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), 0, false);
|
||||
int check = llama_token_to_piece(vocab, token, result.data(), result.size(), 0, false);
|
||||
GGML_ASSERT(check == -n_tokens);
|
||||
} else {
|
||||
result.resize(n_tokens);
|
||||
@ -268,6 +278,10 @@ The transcript only includes text, it does not include markup like HTML and Mark
|
||||
{0}{4})";
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
#if defined(_WIN32)
|
||||
SetConsoleOutputCP(CP_UTF8);
|
||||
#endif
|
||||
|
||||
whisper_params params;
|
||||
|
||||
if (whisper_params_parse(argc, argv, params) == false) {
|
||||
@ -304,12 +318,14 @@ int main(int argc, char ** argv) {
|
||||
lmparams.n_gpu_layers = params.n_gpu_layers;
|
||||
}
|
||||
|
||||
struct llama_model * model_llama = llama_load_model_from_file(params.model_llama.c_str(), lmparams);
|
||||
struct llama_model * model_llama = llama_model_load_from_file(params.model_llama.c_str(), lmparams);
|
||||
if (!model_llama) {
|
||||
fprintf(stderr, "No llama.cpp model specified. Please provide using -ml <modelfile>\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
const llama_vocab * vocab_llama = llama_model_get_vocab(model_llama);
|
||||
|
||||
llama_context_params lcparams = llama_context_default_params();
|
||||
|
||||
// tune these to your liking
|
||||
@ -317,7 +333,7 @@ int main(int argc, char ** argv) {
|
||||
lcparams.n_threads = params.n_threads;
|
||||
lcparams.flash_attn = params.flash_attn;
|
||||
|
||||
struct llama_context * ctx_llama = llama_new_context_with_model(model_llama, lcparams);
|
||||
struct llama_context * ctx_llama = llama_init_from_model(model_llama, lcparams);
|
||||
|
||||
// print some info about the processing
|
||||
{
|
||||
@ -727,7 +743,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const llama_token id = llama_sampler_sample(smpl, ctx_llama, -1);
|
||||
|
||||
if (id != llama_token_eos(model_llama)) {
|
||||
if (id != llama_vocab_eos(vocab_llama)) {
|
||||
// add it to the context
|
||||
embd.push_back(id);
|
||||
|
||||
|
@ -7,18 +7,17 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <codecvt>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <locale>
|
||||
#include <map>
|
||||
#include <regex>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <locale>
|
||||
#include <codecvt>
|
||||
|
||||
size_t unicode_len_utf8(char src) {
|
||||
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
||||
@ -667,18 +666,24 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
|
||||
{ "\\p{N}", unicode_cpt_flags::NUMBER },
|
||||
{ "\\p{L}", unicode_cpt_flags::LETTER },
|
||||
{ "\\p{P}", unicode_cpt_flags::PUNCTUATION },
|
||||
{ "\\p{M}", unicode_cpt_flags::ACCENT_MARK },
|
||||
{ "\\p{S}", unicode_cpt_flags::SYMBOL },
|
||||
};
|
||||
|
||||
static const std::map<int, int> k_ucat_cpt = {
|
||||
{ unicode_cpt_flags::NUMBER, 0xD1 },
|
||||
{ unicode_cpt_flags::LETTER, 0xD2 },
|
||||
{ unicode_cpt_flags::PUNCTUATION, 0xD3 },
|
||||
{ unicode_cpt_flags::ACCENT_MARK, 0xD4 },
|
||||
{ unicode_cpt_flags::SYMBOL, 0xD5 },
|
||||
};
|
||||
|
||||
static const std::map<int, std::string> k_ucat_map = {
|
||||
{ unicode_cpt_flags::NUMBER, "\x30-\x39" }, // 0-9
|
||||
{ unicode_cpt_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
|
||||
{ unicode_cpt_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
|
||||
{ unicode_cpt_flags::ACCENT_MARK, "" }, // no sub-128 codepoints
|
||||
{ unicode_cpt_flags::SYMBOL, "\\\x24\\\x2B\x3C-\x3E\x5E\x60\\\x7C" }, // $+<=>^`|
|
||||
};
|
||||
|
||||
// compute collapsed codepoints only if needed by at least one regex
|
||||
|
@ -3,6 +3,7 @@
|
||||
#include "grammar-parser.h"
|
||||
#include "common.h"
|
||||
#include <thread>
|
||||
#include <chrono>
|
||||
|
||||
WChess::WChess(whisper_context * ctx,
|
||||
const whisper_full_params & wparams,
|
||||
|
@ -7,7 +7,6 @@
|
||||
objects = {
|
||||
|
||||
/* Begin PBXBuildFile section */
|
||||
18133C802C64E342005CEAAC /* ggml-aarch64.c in Sources */ = {isa = PBXBuildFile; fileRef = 18133C7F2C64E342005CEAAC /* ggml-aarch64.c */; };
|
||||
1844471A2AB211A2007D6BFE /* ggml-alloc.c in Sources */ = {isa = PBXBuildFile; fileRef = 184447182AB211A2007D6BFE /* ggml-alloc.c */; };
|
||||
1844471C2AB21655007D6BFE /* ggml-metal.m in Sources */ = {isa = PBXBuildFile; fileRef = 1844471B2AB21655007D6BFE /* ggml-metal.m */; settings = {COMPILER_FLAGS = "-framework Foundation -framework Metal -framework MetalKit -fno-objc-arc"; }; };
|
||||
18627C7B29052BDF00BD2A04 /* AppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C7A29052BDF00BD2A04 /* AppDelegate.m */; };
|
||||
@ -27,9 +26,11 @@
|
||||
18E864A92CE73C1E0094B8B3 /* ggml-cpu.c in Sources */ = {isa = PBXBuildFile; fileRef = 18E864A82CE73C1E0094B8B3 /* ggml-cpu.c */; };
|
||||
18F8C0BC2CEDF4DC00CAD607 /* ggml-threading.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18F8C0BB2CEDF4DC00CAD607 /* ggml-threading.cpp */; };
|
||||
18F8C0BE2CEDF50700CAD607 /* ggml-cpu.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18F8C0BD2CEDF50700CAD607 /* ggml-cpu.cpp */; };
|
||||
18F8C0C42CEDF52700CAD607 /* ggml-cpu-aarch64.c in Sources */ = {isa = PBXBuildFile; fileRef = 18F8C0C02CEDF52700CAD607 /* ggml-cpu-aarch64.c */; };
|
||||
18F8C0C42CEDF52700CAD607 /* ggml-cpu-aarch64.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18F8C0C02CEDF52700CAD607 /* ggml-cpu-aarch64.cpp */; settings = {COMPILER_FLAGS = "-x c++"; }; };
|
||||
18F8C0C52CEDF52700CAD607 /* ggml-cpu-quants.c in Sources */ = {isa = PBXBuildFile; fileRef = 18F8C0C32CEDF52700CAD607 /* ggml-cpu-quants.c */; };
|
||||
18F8C0C72CEDF7AB00CAD607 /* ggml-backend-reg.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18F8C0C62CEDF7AB00CAD607 /* ggml-backend-reg.cpp */; };
|
||||
433188B82D3A187C00E3FE79 /* gguf.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 433188B72D3A187C00E3FE79 /* gguf.cpp */; };
|
||||
437B63E22D36280C002A49EC /* ggml-cpu-traits.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 437B63E12D36280C002A49EC /* ggml-cpu-traits.cpp */; };
|
||||
7FE3424B2A0C3FA20015A058 /* whisper-encoder-impl.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FE342452A0C3FA20015A058 /* whisper-encoder-impl.m */; };
|
||||
7FE3424C2A0C3FA20015A058 /* whisper-encoder.mm in Sources */ = {isa = PBXBuildFile; fileRef = 7FE342472A0C3FA20015A058 /* whisper-encoder.mm */; };
|
||||
7FE3424D2A0C3FA20015A058 /* whisper-decoder-impl.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FE3424A2A0C3FA20015A058 /* whisper-decoder-impl.m */; };
|
||||
@ -51,8 +52,6 @@
|
||||
/* End PBXCopyFilesBuildPhase section */
|
||||
|
||||
/* Begin PBXFileReference section */
|
||||
18133C7E2C64E342005CEAAC /* ggml-aarch64.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = "ggml-aarch64.h"; path = "../../../ggml/src/ggml-aarch64.h"; sourceTree = "<group>"; };
|
||||
18133C7F2C64E342005CEAAC /* ggml-aarch64.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = "ggml-aarch64.c"; path = "../../../ggml/src/ggml-aarch64.c"; sourceTree = "<group>"; };
|
||||
184447182AB211A2007D6BFE /* ggml-alloc.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = "ggml-alloc.c"; path = "../../../ggml/src/ggml-alloc.c"; sourceTree = "<group>"; };
|
||||
184447192AB211A2007D6BFE /* ggml-alloc.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = "ggml-alloc.h"; path = "../../../ggml/include/ggml-alloc.h"; sourceTree = "<group>"; };
|
||||
1844471B2AB21655007D6BFE /* ggml-metal.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; name = "ggml-metal.m"; path = "../../../ggml/src/ggml-metal/ggml-metal.m"; sourceTree = "<group>"; };
|
||||
@ -88,11 +87,15 @@
|
||||
18F8C0BB2CEDF4DC00CAD607 /* ggml-threading.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = "ggml-threading.cpp"; path = "../../../ggml/src/ggml-threading.cpp"; sourceTree = "<group>"; };
|
||||
18F8C0BD2CEDF50700CAD607 /* ggml-cpu.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = "ggml-cpu.cpp"; path = "../../../ggml/src/ggml-cpu/ggml-cpu.cpp"; sourceTree = "<group>"; };
|
||||
18F8C0BF2CEDF52700CAD607 /* ggml-cpu-aarch64.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; name = "ggml-cpu-aarch64.h"; path = "../../../ggml/src/ggml-cpu/ggml-cpu-aarch64.h"; sourceTree = "<group>"; };
|
||||
18F8C0C02CEDF52700CAD607 /* ggml-cpu-aarch64.c */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.c; name = "ggml-cpu-aarch64.c"; path = "../../../ggml/src/ggml-cpu/ggml-cpu-aarch64.c"; sourceTree = "<group>"; };
|
||||
18F8C0C02CEDF52700CAD607 /* ggml-cpu-aarch64.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.c; name = "ggml-cpu-aarch64.cpp"; path = "../../../ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp"; sourceTree = "<group>"; };
|
||||
18F8C0C12CEDF52700CAD607 /* ggml-cpu-impl.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; name = "ggml-cpu-impl.h"; path = "../../../ggml/src/ggml-cpu/ggml-cpu-impl.h"; sourceTree = "<group>"; };
|
||||
18F8C0C22CEDF52700CAD607 /* ggml-cpu-quants.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; name = "ggml-cpu-quants.h"; path = "../../../ggml/src/ggml-cpu/ggml-cpu-quants.h"; sourceTree = "<group>"; };
|
||||
18F8C0C32CEDF52700CAD607 /* ggml-cpu-quants.c */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.c; name = "ggml-cpu-quants.c"; path = "../../../ggml/src/ggml-cpu/ggml-cpu-quants.c"; sourceTree = "<group>"; };
|
||||
18F8C0C62CEDF7AB00CAD607 /* ggml-backend-reg.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = "ggml-backend-reg.cpp"; path = "../../../ggml/src/ggml-backend-reg.cpp"; sourceTree = "<group>"; };
|
||||
433188B72D3A187C00E3FE79 /* gguf.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = gguf.cpp; path = ../../../ggml/src/gguf.cpp; sourceTree = "<group>"; };
|
||||
433188B92D3A18A400E3FE79 /* gguf.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; name = gguf.h; path = ../../../ggml/include/gguf.h; sourceTree = "<group>"; };
|
||||
437B63E02D36280C002A49EC /* ggml-cpu-traits.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; name = "ggml-cpu-traits.h"; path = "../../../ggml/src/ggml-cpu/ggml-cpu-traits.h"; sourceTree = "<group>"; };
|
||||
437B63E12D36280C002A49EC /* ggml-cpu-traits.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = "ggml-cpu-traits.cpp"; path = "../../../ggml/src/ggml-cpu/ggml-cpu-traits.cpp"; sourceTree = "<group>"; };
|
||||
7FE342452A0C3FA20015A058 /* whisper-encoder-impl.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = "whisper-encoder-impl.m"; sourceTree = "<group>"; };
|
||||
7FE342462A0C3FA20015A058 /* whisper-encoder.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "whisper-encoder.h"; sourceTree = "<group>"; };
|
||||
7FE342472A0C3FA20015A058 /* whisper-encoder.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = "whisper-encoder.mm"; sourceTree = "<group>"; };
|
||||
@ -132,10 +135,14 @@
|
||||
18627C7829052BDF00BD2A04 /* whisper.objc */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
433188B92D3A18A400E3FE79 /* gguf.h */,
|
||||
433188B72D3A187C00E3FE79 /* gguf.cpp */,
|
||||
18F8C0C62CEDF7AB00CAD607 /* ggml-backend-reg.cpp */,
|
||||
18F8C0BF2CEDF52700CAD607 /* ggml-cpu-aarch64.h */,
|
||||
18F8C0C02CEDF52700CAD607 /* ggml-cpu-aarch64.c */,
|
||||
18F8C0C02CEDF52700CAD607 /* ggml-cpu-aarch64.cpp */,
|
||||
18F8C0C12CEDF52700CAD607 /* ggml-cpu-impl.h */,
|
||||
437B63E02D36280C002A49EC /* ggml-cpu-traits.h */,
|
||||
437B63E12D36280C002A49EC /* ggml-cpu-traits.cpp */,
|
||||
18F8C0C22CEDF52700CAD607 /* ggml-cpu-quants.h */,
|
||||
18F8C0C32CEDF52700CAD607 /* ggml-cpu-quants.c */,
|
||||
18F8C0BD2CEDF50700CAD607 /* ggml-cpu.cpp */,
|
||||
@ -143,8 +150,6 @@
|
||||
18F8C0BB2CEDF4DC00CAD607 /* ggml-threading.cpp */,
|
||||
18E864AA2CE73C580094B8B3 /* ggml-cpu.h */,
|
||||
18E864A82CE73C1E0094B8B3 /* ggml-cpu.c */,
|
||||
18133C7F2C64E342005CEAAC /* ggml-aarch64.c */,
|
||||
18133C7E2C64E342005CEAAC /* ggml-aarch64.h */,
|
||||
18A275FF2C2A9563001C8D37 /* ggml-common.h */,
|
||||
18A275FE2C2A94DE001C8D37 /* ggml-metal.h */,
|
||||
18ABE1562AF556340044A204 /* ggml-backend-impl.h */,
|
||||
@ -269,21 +274,22 @@
|
||||
files = (
|
||||
18627C8129052BDF00BD2A04 /* ViewController.m in Sources */,
|
||||
18ABE15B2AF556340044A204 /* ggml-quants.c in Sources */,
|
||||
18133C802C64E342005CEAAC /* ggml-aarch64.c in Sources */,
|
||||
7FE3424C2A0C3FA20015A058 /* whisper-encoder.mm in Sources */,
|
||||
18627C9429052C4900BD2A04 /* whisper.cpp in Sources */,
|
||||
437B63E22D36280C002A49EC /* ggml-cpu-traits.cpp in Sources */,
|
||||
18627C9629052C5800BD2A04 /* ggml.c in Sources */,
|
||||
18627C7B29052BDF00BD2A04 /* AppDelegate.m in Sources */,
|
||||
7FE3424D2A0C3FA20015A058 /* whisper-decoder-impl.m in Sources */,
|
||||
18F8C0C72CEDF7AB00CAD607 /* ggml-backend-reg.cpp in Sources */,
|
||||
18F8C0BE2CEDF50700CAD607 /* ggml-cpu.cpp in Sources */,
|
||||
1844471A2AB211A2007D6BFE /* ggml-alloc.c in Sources */,
|
||||
18F8C0C42CEDF52700CAD607 /* ggml-cpu-aarch64.c in Sources */,
|
||||
18F8C0C42CEDF52700CAD607 /* ggml-cpu-aarch64.cpp in Sources */,
|
||||
18F8C0C52CEDF52700CAD607 /* ggml-cpu-quants.c in Sources */,
|
||||
18E864A92CE73C1E0094B8B3 /* ggml-cpu.c in Sources */,
|
||||
18ABE15A2AF556340044A204 /* ggml-backend.cpp in Sources */,
|
||||
18627C8C29052BE000BD2A04 /* main.m in Sources */,
|
||||
18627C7E29052BDF00BD2A04 /* SceneDelegate.m in Sources */,
|
||||
433188B82D3A187C00E3FE79 /* gguf.cpp in Sources */,
|
||||
18F8C0BC2CEDF4DC00CAD607 /* ggml-threading.cpp in Sources */,
|
||||
1844471C2AB21655007D6BFE /* ggml-metal.m in Sources */,
|
||||
7FE3424B2A0C3FA20015A058 /* whisper-encoder-impl.m in Sources */,
|
||||
|
@ -58,7 +58,8 @@ else()
|
||||
set(GGML_BLAS_VENDOR_DEFAULT "Generic")
|
||||
endif()
|
||||
|
||||
if (CMAKE_CROSSCOMPILING)
|
||||
if (CMAKE_CROSSCOMPILING OR DEFINED ENV{SOURCE_DATE_EPOCH})
|
||||
message(STATUS "Setting GGML_NATIVE_DEFAULT to OFF")
|
||||
set(GGML_NATIVE_DEFAULT OFF)
|
||||
else()
|
||||
set(GGML_NATIVE_DEFAULT ON)
|
||||
@ -74,10 +75,10 @@ if (NOT GGML_CUDA_GRAPHS_DEFAULT)
|
||||
endif()
|
||||
|
||||
# general
|
||||
option(GGML_STATIC "ggml: static link libraries" OFF)
|
||||
option(GGML_NATIVE "ggml: enable -march=native flag" ${GGML_NATIVE_DEFAULT})
|
||||
option(GGML_LTO "ggml: enable link time optimization" OFF)
|
||||
option(GGML_CCACHE "ggml: use ccache if available" ON)
|
||||
option(GGML_STATIC "ggml: static link libraries" OFF)
|
||||
option(GGML_NATIVE "ggml: optimize the build for the current system" ${GGML_NATIVE_DEFAULT})
|
||||
option(GGML_LTO "ggml: enable link time optimization" OFF)
|
||||
option(GGML_CCACHE "ggml: use ccache if available" ON)
|
||||
|
||||
# debug
|
||||
option(GGML_ALL_WARNINGS "ggml: enable all compiler warnings" ON)
|
||||
@ -120,8 +121,9 @@ endif()
|
||||
option(GGML_LASX "ggml: enable lasx" ON)
|
||||
option(GGML_LSX "ggml: enable lsx" ON)
|
||||
option(GGML_RVV "ggml: enable rvv" ON)
|
||||
option(GGML_SVE "ggml: enable SVE" OFF)
|
||||
|
||||
option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
|
||||
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")
|
||||
|
||||
|
||||
if (WIN32)
|
||||
@ -152,6 +154,8 @@ option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashA
|
||||
option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})
|
||||
|
||||
option(GGML_HIP "ggml: use HIP" OFF)
|
||||
option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF)
|
||||
option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
|
||||
option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF)
|
||||
option(GGML_VULKAN "ggml: use Vulkan" OFF)
|
||||
option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)
|
||||
@ -184,6 +188,9 @@ option(GGML_OPENCL_PROFILING "ggml: use OpenCL profiling (increas
|
||||
option(GGML_OPENCL_EMBED_KERNELS "ggml: embed kernels" ON)
|
||||
option(GGML_OPENCL_USE_ADRENO_KERNELS "ggml: use optimized kernels for Adreno" ON)
|
||||
|
||||
# toolchain for vulkan-shaders-gen
|
||||
set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen")
|
||||
|
||||
# extra artifacts
|
||||
option(GGML_BUILD_TESTS "ggml: build tests" ${GGML_STANDALONE})
|
||||
option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE})
|
||||
@ -242,7 +249,8 @@ set(GGML_PUBLIC_HEADERS
|
||||
include/ggml-metal.h
|
||||
include/ggml-rpc.h
|
||||
include/ggml-sycl.h
|
||||
include/ggml-vulkan.h)
|
||||
include/ggml-vulkan.h
|
||||
include/gguf.h)
|
||||
|
||||
set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
|
||||
#if (GGML_METAL)
|
||||
@ -251,26 +259,6 @@ set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
|
||||
install(TARGETS ggml LIBRARY PUBLIC_HEADER)
|
||||
install(TARGETS ggml-base LIBRARY)
|
||||
|
||||
# FIXME: this should be done in the backend cmake files
|
||||
if (GGML_METAL)
|
||||
# FIXME: does this need to be installed with GGML_METAL_EMBED_LIBRARY?
|
||||
install(
|
||||
FILES src/ggml-metal/ggml-metal.metal
|
||||
PERMISSIONS
|
||||
OWNER_READ
|
||||
OWNER_WRITE
|
||||
GROUP_READ
|
||||
WORLD_READ
|
||||
DESTINATION ${CMAKE_INSTALL_BINDIR})
|
||||
|
||||
if (NOT GGML_METAL_EMBED_LIBRARY)
|
||||
install(
|
||||
FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
|
||||
DESTINATION ${CMAKE_INSTALL_BINDIR}
|
||||
)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (GGML_STANDALONE)
|
||||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/ggml.pc.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/ggml.pc
|
||||
@ -279,3 +267,77 @@ if (GGML_STANDALONE)
|
||||
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc
|
||||
DESTINATION share/pkgconfig)
|
||||
endif()
|
||||
|
||||
#
|
||||
# Create CMake package
|
||||
#
|
||||
|
||||
# Generate version info based on git commit.
|
||||
|
||||
if(NOT DEFINED GGML_BUILD_NUMBER)
|
||||
find_program(GIT_EXE NAMES git git.exe REQUIRED NO_CMAKE_FIND_ROOT_PATH)
|
||||
execute_process(COMMAND ${GIT_EXE} rev-list --count HEAD
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
OUTPUT_VARIABLE GGML_BUILD_NUMBER
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
|
||||
if(GGML_BUILD_NUMBER EQUAL 1)
|
||||
message(WARNING "GGML build version fixed at 1 likely due to a shallow clone.")
|
||||
endif()
|
||||
|
||||
execute_process(COMMAND ${GIT_EXE} rev-parse --short HEAD
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
OUTPUT_VARIABLE GGML_BUILD_COMMIT
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
# Capture variables prefixed with GGML_.
|
||||
|
||||
set(variable_set_statements
|
||||
"
|
||||
####### Expanded from @GGML_VARIABLES_EXPANED@ by configure_package_config_file() #######
|
||||
####### Any changes to this file will be overwritten by the next CMake run #######
|
||||
|
||||
")
|
||||
|
||||
set(GGML_SHARED_LIB ${BUILD_SHARED_LIBS})
|
||||
|
||||
get_cmake_property(all_variables VARIABLES)
|
||||
foreach(variable_name IN LISTS all_variables)
|
||||
if(variable_name MATCHES "^GGML_")
|
||||
string(REPLACE ";" "\\;"
|
||||
variable_value "${${variable_name}}")
|
||||
|
||||
set(variable_set_statements
|
||||
"${variable_set_statements}set(${variable_name} \"${variable_value}\")\n")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
set(GGML_VARIABLES_EXPANDED ${variable_set_statements})
|
||||
|
||||
# Create the CMake package and set install location.
|
||||
|
||||
set(GGML_INSTALL_VERSION 0.0.${GGML_BUILD_NUMBER})
|
||||
set(GGML_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location of header files")
|
||||
set(GGML_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files")
|
||||
set(GGML_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files")
|
||||
|
||||
configure_package_config_file(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cmake/ggml-config.cmake.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake
|
||||
INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml
|
||||
PATH_VARS GGML_INCLUDE_INSTALL_DIR
|
||||
GGML_LIB_INSTALL_DIR
|
||||
GGML_BIN_INSTALL_DIR)
|
||||
|
||||
write_basic_package_version_file(
|
||||
${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake
|
||||
VERSION ${GGML_INSTALL_VERSION}
|
||||
COMPATIBILITY SameMajorVersion)
|
||||
|
||||
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake
|
||||
${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml)
|
||||
|
54
ggml/cmake/BuildTypes.cmake
Normal file
54
ggml/cmake/BuildTypes.cmake
Normal file
@ -0,0 +1,54 @@
|
||||
# Add new build types
|
||||
|
||||
# ReleaseGG - Release with enabled asserts
|
||||
|
||||
SET(CMAKE_CXX_FLAGS_RELEASEGG
|
||||
"-O3"
|
||||
CACHE STRING "Flags used by the c++ compiler during release builds with enabled asserts."
|
||||
FORCE )
|
||||
SET(CMAKE_C_FLAGS_RELEASEGG
|
||||
"-O3"
|
||||
CACHE STRING "Flags used by the compiler during release builds with enabled asserts."
|
||||
FORCE )
|
||||
SET(CMAKE_EXE_LINKER_FLAGS_RELEASEGG
|
||||
""
|
||||
CACHE STRING "Flags used for linking binaries during release builds with enabled asserts."
|
||||
FORCE )
|
||||
SET(CMAKE_SHARED_LINKER_FLAGS_RELEASEGG
|
||||
""
|
||||
CACHE STRING "Flags used by the shared libraries linker during release builds with enabled asserts."
|
||||
FORCE )
|
||||
MARK_AS_ADVANCED(
|
||||
CMAKE_CXX_FLAGS_RELEASEGG
|
||||
CMAKE_C_FLAGS_RELEASEGG
|
||||
CMAKE_EXE_LINKER_FLAGS_RELEASEGG
|
||||
CMAKE_SHARED_LINKER_FLAGS_RELEASEGG )
|
||||
|
||||
# RelWithDebInfoGG - RelWithDebInfo with enabled asserts
|
||||
|
||||
SET(CMAKE_CXX_FLAGS_RELWITHDEBINFOGG
|
||||
"-O2 -g"
|
||||
CACHE STRING "Flags used by the c++ compiler during release builds with debug symbols and enabled asserts."
|
||||
FORCE )
|
||||
SET(CMAKE_C_FLAGS_RELWITHDEBINFOGG
|
||||
"-O2 -g"
|
||||
CACHE STRING "Flags used by the compiler during release builds with debug symbols and enabled asserts."
|
||||
FORCE )
|
||||
SET(CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFOGG
|
||||
""
|
||||
CACHE STRING "Flags used for linking binaries during release builds with debug symbols and enabled asserts."
|
||||
FORCE )
|
||||
SET(CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFOGG
|
||||
""
|
||||
CACHE STRING "Flags used by the shared libraries linker during release builds with debug symbols and enabled asserts."
|
||||
FORCE )
|
||||
MARK_AS_ADVANCED(
|
||||
CMAKE_CXX_FLAGS_RELWITHDEBINFOGG
|
||||
CMAKE_C_FLAGS_RELWITHDEBINFOGG
|
||||
CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFOGG
|
||||
CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFOGG )
|
||||
|
||||
if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
|
||||
set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
|
||||
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo" "ReleaseGG" "RelWithDebInfoGG")
|
||||
endif()
|
22
ggml/cmake/GitVars.cmake
Normal file
22
ggml/cmake/GitVars.cmake
Normal file
@ -0,0 +1,22 @@
|
||||
find_package(Git)
|
||||
|
||||
# the commit's SHA1
|
||||
execute_process(COMMAND
|
||||
"${GIT_EXECUTABLE}" describe --match=NeVeRmAtCh --always --abbrev=8
|
||||
WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
|
||||
OUTPUT_VARIABLE GIT_SHA1
|
||||
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||
|
||||
# the date of the commit
|
||||
execute_process(COMMAND
|
||||
"${GIT_EXECUTABLE}" log -1 --format=%ad --date=local
|
||||
WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
|
||||
OUTPUT_VARIABLE GIT_DATE
|
||||
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||
|
||||
# the subject of the commit
|
||||
execute_process(COMMAND
|
||||
"${GIT_EXECUTABLE}" log -1 --format=%s
|
||||
WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
|
||||
OUTPUT_VARIABLE GIT_COMMIT_SUBJECT
|
||||
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
|
147
ggml/cmake/ggml-config.cmake.in
Normal file
147
ggml/cmake/ggml-config.cmake.in
Normal file
@ -0,0 +1,147 @@
|
||||
|
||||
@GGML_VARIABLES_EXPANDED@
|
||||
|
||||
@PACKAGE_INIT@
|
||||
|
||||
set_and_check(GGML_INCLUDE_DIR "@PACKAGE_GGML_INCLUDE_INSTALL_DIR@")
|
||||
set_and_check(GGML_LIB_DIR "@PACKAGE_GGML_LIB_INSTALL_DIR@")
|
||||
set_and_check(GGML_BIN_DIR "@PACKAGE_GGML_BIN_INSTALL_DIR@")
|
||||
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
find_library(GGML_LIBRARY ggml
|
||||
REQUIRED
|
||||
HINTS ${GGML_LIB_DIR}
|
||||
NO_CMAKE_FIND_ROOT_PATH)
|
||||
|
||||
add_library(ggml::ggml UNKNOWN IMPORTED)
|
||||
set_target_properties(ggml::ggml
|
||||
PROPERTIES
|
||||
IMPORTED_LOCATION "${GGML_LIBRARY}")
|
||||
|
||||
find_library(GGML_BASE_LIBRARY ggml-base
|
||||
REQUIRED
|
||||
HINTS ${GGML_LIB_DIR}
|
||||
NO_CMAKE_FIND_ROOT_PATH)
|
||||
|
||||
add_library(ggml::ggml-base UNKNOWN IMPORTED)
|
||||
set_target_properties(ggml::ggml-base
|
||||
PROPERTIES
|
||||
IMPORTED_LOCATION "${GGML_BASE_LIBRARY}")
|
||||
|
||||
if (NOT GGML_SHARED_LIB)
|
||||
if (APPLE AND GGML_ACCELERATE)
|
||||
find_library(ACCELERATE_FRAMEWORK Accelerate REQUIRED)
|
||||
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${ACCELERATE_FRAMEWORK})
|
||||
endif()
|
||||
|
||||
if (GGML_OPENMP)
|
||||
find_package(OpenMP REQUIRED)
|
||||
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
|
||||
endif()
|
||||
|
||||
if (GGML_CPU_HBM)
|
||||
find_library(memkind memkind REQUIRED)
|
||||
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES memkind)
|
||||
endif()
|
||||
|
||||
if (GGML_BLAS)
|
||||
find_package(BLAS REQUIRED)
|
||||
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${BLAS_LIBRARIES})
|
||||
list(APPEND GGML_CPU_INTERFACE_LINK_OPTIONS ${BLAS_LINKER_FLAGS})
|
||||
endif()
|
||||
|
||||
if (GGML_CUDA)
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
endif()
|
||||
|
||||
if (GGML_METAL)
|
||||
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
|
||||
find_library(METAL_FRAMEWORK Metal REQUIRED)
|
||||
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
|
||||
|
||||
list(APPEND GGML_METAL_INTERFACE_LINK_LIBRARIES
|
||||
${FOUNDATION_LIBRARY} ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK})
|
||||
endif()
|
||||
|
||||
if (GGML_VULKAN)
|
||||
find_package(Vulkan REQUIRED)
|
||||
list(APPEND GGML_VULKAN_INTERFACE_LINK_LIBRARIES Vulkan::Vulkan)
|
||||
endif()
|
||||
|
||||
if (GGML_HIP)
|
||||
find_package(hip REQUIRED)
|
||||
find_package(hipblas REQUIRED)
|
||||
find_package(rocblas REQUIRED)
|
||||
list(APPEND GGML_HIP_INTERFACE_LINK_LIBRARIES hip::host roc::rocblas roc::hipblas)
|
||||
endif()
|
||||
|
||||
if (GGML_SYCL)
|
||||
find_package(DNNL)
|
||||
if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
|
||||
list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES DNNL::dnnl)
|
||||
endif()
|
||||
if (WIN32)
|
||||
find_package(IntelSYCL REQUIRED)
|
||||
find_package(MKL REQUIRED)
|
||||
list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(_ggml_all_targets "")
|
||||
foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS})
|
||||
string(REPLACE "-" "_" _ggml_backend_pfx "${_ggml_backend}")
|
||||
string(TOUPPER "${_ggml_backend_pfx}" _ggml_backend_pfx)
|
||||
|
||||
find_library(${_ggml_backend_pfx}_LIBRARY ${_ggml_backend}
|
||||
REQUIRED
|
||||
HINTS ${GGML_LIB_DIR}
|
||||
NO_CMAKE_FIND_ROOT_PATH)
|
||||
|
||||
message(STATUS "Found ${${_ggml_backend_pfx}_LIBRARY}")
|
||||
|
||||
add_library(ggml::${_ggml_backend} UNKNOWN IMPORTED)
|
||||
set_target_properties(ggml::${_ggml_backend}
|
||||
PROPERTIES
|
||||
INTERFACE_INCLUDE_DIRECTORIES "${GGML_INCLUDE_DIR}"
|
||||
IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
|
||||
IMPORTED_LOCATION "${${_ggml_backend_pfx}_LIBRARY}"
|
||||
INTERFACE_COMPILE_FEATURES c_std_90
|
||||
POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
string(REGEX MATCH "^ggml-cpu" is_cpu_variant "${_ggml_backend}")
|
||||
if(is_cpu_variant)
|
||||
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml" "ggml::ggml-base")
|
||||
set_target_properties(ggml::${_ggml_backend}
|
||||
PROPERTIES
|
||||
INTERFACE_LINK_LIBRARIES "${GGML_CPU_INTERFACE_LINK_LIBRARIES}")
|
||||
|
||||
if(GGML_CPU_INTERFACE_LINK_OPTIONS)
|
||||
set_target_properties(ggml::${_ggml_backend}
|
||||
PROPERTIES
|
||||
INTERFACE_LINK_OPTIONS "${GGML_CPU_INTERFACE_LINK_OPTIONS}")
|
||||
endif()
|
||||
|
||||
else()
|
||||
list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml" "ggml::ggml-base")
|
||||
set_target_properties(ggml::${_ggml_backend}
|
||||
PROPERTIES
|
||||
INTERFACE_LINK_LIBRARIES "${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}")
|
||||
|
||||
if(${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS)
|
||||
set_target_properties(ggml::${_ggml_backend}
|
||||
PROPERTIES
|
||||
INTERFACE_LINK_OPTIONS "${${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS}")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
list(APPEND _ggml_all_targets ggml::${_ggml_backend})
|
||||
endforeach()
|
||||
|
||||
add_library(ggml::all INTERFACE IMPORTED)
|
||||
set_target_properties(ggml::all
|
||||
PROPERTIES
|
||||
INTERFACE_LINK_LIBRARIES "${_ggml_all_targets}")
|
||||
|
||||
check_required_components(ggml)
|
@ -203,6 +203,8 @@ extern "C" {
|
||||
// Backend registry
|
||||
//
|
||||
|
||||
GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
|
||||
|
||||
// Backend (reg) enumeration
|
||||
GGML_API size_t ggml_backend_reg_count(void);
|
||||
GGML_API ggml_backend_reg_t ggml_backend_reg_get(size_t index);
|
||||
|
@ -7,6 +7,7 @@
|
||||
#include "ggml.h"
|
||||
#include "ggml-alloc.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "gguf.h"
|
||||
#include <memory>
|
||||
|
||||
// Smart pointers for ggml types
|
||||
|
@ -241,12 +241,6 @@
|
||||
#define GGML_ROPE_TYPE_MROPE 8
|
||||
#define GGML_ROPE_TYPE_VISION 24
|
||||
|
||||
#define GGUF_MAGIC "GGUF"
|
||||
|
||||
#define GGUF_VERSION 3
|
||||
|
||||
#define GGUF_DEFAULT_ALIGNMENT 32
|
||||
|
||||
#define GGML_UNUSED(x) (void)(x)
|
||||
|
||||
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
|
||||
@ -403,12 +397,6 @@ extern "C" {
|
||||
GGML_PREC_F32,
|
||||
};
|
||||
|
||||
enum ggml_backend_type {
|
||||
GGML_BACKEND_TYPE_CPU = 0,
|
||||
GGML_BACKEND_TYPE_GPU = 10,
|
||||
GGML_BACKEND_TYPE_GPU_SPLIT = 20,
|
||||
};
|
||||
|
||||
// model file types
|
||||
enum ggml_ftype {
|
||||
GGML_FTYPE_UNKNOWN = -1,
|
||||
@ -513,6 +501,7 @@ extern "C" {
|
||||
GGML_OP_GET_REL_POS,
|
||||
GGML_OP_ADD_REL_POS,
|
||||
GGML_OP_RWKV_WKV6,
|
||||
GGML_OP_GATED_LINEAR_ATTN,
|
||||
|
||||
GGML_OP_UNARY,
|
||||
|
||||
@ -587,8 +576,6 @@ extern "C" {
|
||||
struct ggml_tensor {
|
||||
enum ggml_type type;
|
||||
|
||||
GGML_DEPRECATED(enum ggml_backend_type backend, "use the buffer type to find the storage location of the tensor");
|
||||
|
||||
struct ggml_backend_buffer * buffer;
|
||||
|
||||
int64_t ne[GGML_MAX_DIMS]; // number of elements
|
||||
@ -1397,16 +1384,20 @@ extern "C" {
|
||||
float scale,
|
||||
float max_bias);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_soft_max_back(
|
||||
GGML_API struct ggml_tensor * ggml_soft_max_ext_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
struct ggml_tensor * b,
|
||||
float scale,
|
||||
float max_bias);
|
||||
|
||||
// in-place, returns view(a)
|
||||
GGML_API struct ggml_tensor * ggml_soft_max_back_inplace(
|
||||
GGML_API struct ggml_tensor * ggml_soft_max_ext_back_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
struct ggml_tensor * b,
|
||||
float scale,
|
||||
float max_bias);
|
||||
|
||||
// rotary position embedding
|
||||
// if (mode & 1) - skip n_past elements (NOT SUPPORTED)
|
||||
@ -1513,7 +1504,7 @@ extern "C" {
|
||||
|
||||
// rotary position embedding backward, i.e compute dx from dy
|
||||
// a - dy
|
||||
GGML_API struct ggml_tensor * ggml_rope_back(
|
||||
GGML_API struct ggml_tensor * ggml_rope_ext_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a, // gradients of ggml_rope result
|
||||
struct ggml_tensor * b, // positions
|
||||
@ -1528,6 +1519,23 @@ extern "C" {
|
||||
float beta_fast,
|
||||
float beta_slow);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_rope_multi_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c,
|
||||
int n_dims,
|
||||
int sections[4],
|
||||
int mode,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow);
|
||||
|
||||
|
||||
// clamp
|
||||
// in-place, returns view(a)
|
||||
GGML_API struct ggml_tensor * ggml_clamp(
|
||||
@ -1564,17 +1572,6 @@ extern "C" {
|
||||
int d1, // dilation dimension 1
|
||||
bool is_2D);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_conv_depthwise_2d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a, // convolution kernel
|
||||
struct ggml_tensor * b, // data
|
||||
int s0, // stride dimension 0
|
||||
int s1, // stride dimension 1
|
||||
int p0, // padding dimension 0
|
||||
int p1, // padding dimension 1
|
||||
int d0, // dilation dimension 0
|
||||
int d1); // dilation dimension 1
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_conv_1d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a, // convolution kernel
|
||||
@ -1592,6 +1589,23 @@ extern "C" {
|
||||
int s, // stride
|
||||
int d); // dilation
|
||||
|
||||
// depthwise
|
||||
// TODO: this is very likely wrong for some cases! - needs more testing
|
||||
GGML_API struct ggml_tensor * ggml_conv_1d_dw(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a, // convolution kernel
|
||||
struct ggml_tensor * b, // data
|
||||
int s0, // stride
|
||||
int p0, // padding
|
||||
int d0); // dilation
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_conv_1d_dw_ph(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a, // convolution kernel
|
||||
struct ggml_tensor * b, // data
|
||||
int s0, // stride
|
||||
int d0); // dilation
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a, // convolution kernel
|
||||
@ -1611,7 +1625,6 @@ extern "C" {
|
||||
int d0, // dilation dimension 0
|
||||
int d1); // dilation dimension 1
|
||||
|
||||
|
||||
// kernel size is a->ne[0] x a->ne[1]
|
||||
// stride is equal to kernel size
|
||||
// padding is zero
|
||||
@ -1638,6 +1651,18 @@ extern "C" {
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
|
||||
// depthwise
|
||||
GGML_API struct ggml_tensor * ggml_conv_2d_dw(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a, // convolution kernel
|
||||
struct ggml_tensor * b, // data
|
||||
int s0, // stride dimension 0
|
||||
int s1, // stride dimension 1
|
||||
int p0, // padding dimension 0
|
||||
int p1, // padding dimension 1
|
||||
int d0, // dilation dimension 0
|
||||
int d1); // dilation dimension 1
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
@ -1750,7 +1775,7 @@ extern "C" {
|
||||
struct ggml_tensor * a,
|
||||
int k);
|
||||
|
||||
#define GGML_KQ_MASK_PAD 32
|
||||
#define GGML_KQ_MASK_PAD 64
|
||||
|
||||
// q: [n_embd, n_batch, n_head, 1]
|
||||
// k: [n_embd, n_kv, n_head_kv, 1]
|
||||
@ -1856,6 +1881,15 @@ extern "C" {
|
||||
struct ggml_tensor * td,
|
||||
struct ggml_tensor * state);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_gated_linear_attn(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * k,
|
||||
struct ggml_tensor * v,
|
||||
struct ggml_tensor * q,
|
||||
struct ggml_tensor * g,
|
||||
struct ggml_tensor * state,
|
||||
float scale);
|
||||
|
||||
// custom operators
|
||||
|
||||
typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
|
||||
@ -2094,132 +2128,6 @@ extern "C" {
|
||||
int64_t n_per_row,
|
||||
const float * imatrix);
|
||||
|
||||
//
|
||||
// gguf
|
||||
//
|
||||
|
||||
enum gguf_type {
|
||||
GGUF_TYPE_UINT8 = 0,
|
||||
GGUF_TYPE_INT8 = 1,
|
||||
GGUF_TYPE_UINT16 = 2,
|
||||
GGUF_TYPE_INT16 = 3,
|
||||
GGUF_TYPE_UINT32 = 4,
|
||||
GGUF_TYPE_INT32 = 5,
|
||||
GGUF_TYPE_FLOAT32 = 6,
|
||||
GGUF_TYPE_BOOL = 7,
|
||||
GGUF_TYPE_STRING = 8,
|
||||
GGUF_TYPE_ARRAY = 9,
|
||||
GGUF_TYPE_UINT64 = 10,
|
||||
GGUF_TYPE_INT64 = 11,
|
||||
GGUF_TYPE_FLOAT64 = 12,
|
||||
GGUF_TYPE_COUNT, // marks the end of the enum
|
||||
};
|
||||
|
||||
struct gguf_context;
|
||||
|
||||
struct gguf_init_params {
|
||||
bool no_alloc;
|
||||
|
||||
// if not NULL, create a ggml_context and allocate the tensor data in it
|
||||
struct ggml_context ** ctx;
|
||||
};
|
||||
|
||||
GGML_API struct gguf_context * gguf_init_empty(void);
|
||||
GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
|
||||
//GGML_API struct gguf_context * gguf_init_from_buffer(..);
|
||||
|
||||
GGML_API void gguf_free(struct gguf_context * ctx);
|
||||
|
||||
GGML_API const char * gguf_type_name(enum gguf_type type);
|
||||
|
||||
GGML_API int gguf_get_version (const struct gguf_context * ctx);
|
||||
GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx);
|
||||
GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx);
|
||||
GGML_API void * gguf_get_data (const struct gguf_context * ctx);
|
||||
|
||||
GGML_API int gguf_get_n_kv(const struct gguf_context * ctx);
|
||||
GGML_API int gguf_find_key(const struct gguf_context * ctx, const char * key);
|
||||
GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int key_id);
|
||||
|
||||
GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int key_id);
|
||||
GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id);
|
||||
|
||||
// will abort if the wrong type is used for the key
|
||||
GGML_API uint8_t gguf_get_val_u8 (const struct gguf_context * ctx, int key_id);
|
||||
GGML_API int8_t gguf_get_val_i8 (const struct gguf_context * ctx, int key_id);
|
||||
GGML_API uint16_t gguf_get_val_u16 (const struct gguf_context * ctx, int key_id);
|
||||
GGML_API int16_t gguf_get_val_i16 (const struct gguf_context * ctx, int key_id);
|
||||
GGML_API uint32_t gguf_get_val_u32 (const struct gguf_context * ctx, int key_id);
|
||||
GGML_API int32_t gguf_get_val_i32 (const struct gguf_context * ctx, int key_id);
|
||||
GGML_API float gguf_get_val_f32 (const struct gguf_context * ctx, int key_id);
|
||||
GGML_API uint64_t gguf_get_val_u64 (const struct gguf_context * ctx, int key_id);
|
||||
GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int key_id);
|
||||
GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
|
||||
GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
|
||||
GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
|
||||
GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id);
|
||||
GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id);
|
||||
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
|
||||
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
|
||||
|
||||
GGML_API int gguf_get_n_tensors (const struct gguf_context * ctx);
|
||||
GGML_API int gguf_find_tensor (const struct gguf_context * ctx, const char * name);
|
||||
GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i);
|
||||
GGML_API char * gguf_get_tensor_name (const struct gguf_context * ctx, int i);
|
||||
GGML_API enum ggml_type gguf_get_tensor_type (const struct gguf_context * ctx, int i);
|
||||
|
||||
// removes key if it exists
|
||||
GGML_API void gguf_remove_key(struct gguf_context * ctx, const char * key);
|
||||
|
||||
// overrides existing values or adds a new one
|
||||
GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val);
|
||||
GGML_API void gguf_set_val_i8 (struct gguf_context * ctx, const char * key, int8_t val);
|
||||
GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val);
|
||||
GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t val);
|
||||
GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val);
|
||||
GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t val);
|
||||
GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float val);
|
||||
GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val);
|
||||
GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t val);
|
||||
GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double val);
|
||||
GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val);
|
||||
GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);
|
||||
GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n);
|
||||
GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, int n);
|
||||
|
||||
// set or add KV pairs from another context
|
||||
GGML_API void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src);
|
||||
|
||||
// manage tensor info
|
||||
GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor);
|
||||
GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type);
|
||||
GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size);
|
||||
|
||||
// writing gguf files can be done in 2 ways:
|
||||
//
|
||||
// - write the entire gguf_context to a binary file in a single pass:
|
||||
//
|
||||
// gguf_write_to_file(ctx, fname);
|
||||
//
|
||||
// - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:
|
||||
//
|
||||
// FILE * f = fopen(fname, "wb");
|
||||
// fseek(f, gguf_get_meta_size(ctx), SEEK_SET);
|
||||
// fwrite(f, ...);
|
||||
// void * data = gguf_meta_get_meta_data(ctx);
|
||||
// fseek(f, 0, SEEK_SET);
|
||||
// fwrite(f, data, gguf_get_meta_size(ctx));
|
||||
// free(data);
|
||||
// fclose(f);
|
||||
//
|
||||
|
||||
// write the entire context to a binary file
|
||||
GGML_API void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta);
|
||||
|
||||
// get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
|
||||
GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
|
||||
GGML_API void gguf_get_meta_data(const struct gguf_context * ctx, void * data);
|
||||
|
||||
#ifdef __cplusplus
|
||||
// restrict not standard in C++
|
||||
# if defined(__GNUC__)
|
||||
|
202
ggml/include/gguf.h
Normal file
202
ggml/include/gguf.h
Normal file
@ -0,0 +1,202 @@
|
||||
// This file contains functionality related to "GGUF" files, the binary file format used by ggml.
|
||||
// GGUF files have the following structure:
|
||||
//
|
||||
// 1. File magic "GGUF" (4 bytes).
|
||||
// 2. File version (uint32_t).
|
||||
// 3. Number of ggml tensors in file (int64_t).
|
||||
// 4. Number of key-value-pairs in file (int64_t).
|
||||
// 5. For each KV pair:
|
||||
// 1. The key (string).
|
||||
// 2. The value type (gguf_type).
|
||||
// 3a. If the value type is GGUF_TYPE_ARRAY:
|
||||
// 1. The type of the array (gguf_type).
|
||||
// 2. The number of elements in the array (uint64_t).
|
||||
// 3. The binary representation of each element in the array.
|
||||
// 3b. Otherwise:
|
||||
// 1. The binary representation of the value.
|
||||
// 6. For each ggml tensor:
|
||||
// 1. The tensor name (string).
|
||||
// 2. The number of dimensions of the tensor (uint32_t).
|
||||
// 3. For each dimension:
|
||||
// 1. The size of the tensor in the dimension (int64_t).
|
||||
// 4. The tensor data type (ggml_type).
|
||||
// 5. The tensor data offset in the tensor data binary blob (uint64_t).
|
||||
// 7. The tensor data binary blob (optional, aligned).
|
||||
//
|
||||
// Strings are serialized as the string length (uint64_t) followed by the C string without the null terminator.
|
||||
// All enums are stored as int32_t.
|
||||
// All bool values are stored as int8_t.
|
||||
// If the special key "general.alignment" (uint32_t) is defined it is used for alignment,
|
||||
// otherwise GGUF_DEFAULT_ALIGNMENT is used.
|
||||
//
|
||||
// Module maintainer: Johannes Gäßler (@JohannesGaessler, johannesg@5d6.de)
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#define GGUF_MAGIC "GGUF"
|
||||
#define GGUF_VERSION 3
|
||||
|
||||
#define GGUF_KEY_GENERAL_ALIGNMENT "general.alignment"
|
||||
|
||||
#define GGUF_DEFAULT_ALIGNMENT 32
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// types that can be stored as GGUF KV data
|
||||
enum gguf_type {
|
||||
GGUF_TYPE_UINT8 = 0,
|
||||
GGUF_TYPE_INT8 = 1,
|
||||
GGUF_TYPE_UINT16 = 2,
|
||||
GGUF_TYPE_INT16 = 3,
|
||||
GGUF_TYPE_UINT32 = 4,
|
||||
GGUF_TYPE_INT32 = 5,
|
||||
GGUF_TYPE_FLOAT32 = 6,
|
||||
GGUF_TYPE_BOOL = 7,
|
||||
GGUF_TYPE_STRING = 8,
|
||||
GGUF_TYPE_ARRAY = 9,
|
||||
GGUF_TYPE_UINT64 = 10,
|
||||
GGUF_TYPE_INT64 = 11,
|
||||
GGUF_TYPE_FLOAT64 = 12,
|
||||
GGUF_TYPE_COUNT, // marks the end of the enum
|
||||
};
|
||||
|
||||
struct gguf_context;
|
||||
|
||||
struct gguf_init_params {
|
||||
bool no_alloc;
|
||||
|
||||
// if not NULL, create a ggml_context and allocate the tensor data in it
|
||||
struct ggml_context ** ctx;
|
||||
};
|
||||
|
||||
GGML_API struct gguf_context * gguf_init_empty(void);
|
||||
GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
|
||||
//GGML_API struct gguf_context * gguf_init_from_buffer(..);
|
||||
|
||||
GGML_API void gguf_free(struct gguf_context * ctx);
|
||||
|
||||
GGML_API const char * gguf_type_name(enum gguf_type type);
|
||||
|
||||
GGML_API uint32_t gguf_get_version (const struct gguf_context * ctx);
|
||||
GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx);
|
||||
GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx);
|
||||
|
||||
GGML_API int64_t gguf_get_n_kv(const struct gguf_context * ctx);
|
||||
GGML_API int64_t gguf_find_key(const struct gguf_context * ctx, const char * key); // returns -1 if key is not found
|
||||
GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int64_t key_id);
|
||||
|
||||
GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int64_t key_id);
|
||||
GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id);
|
||||
|
||||
// will abort if the wrong type is used for the key
|
||||
GGML_API uint8_t gguf_get_val_u8 (const struct gguf_context * ctx, int64_t key_id);
|
||||
GGML_API int8_t gguf_get_val_i8 (const struct gguf_context * ctx, int64_t key_id);
|
||||
GGML_API uint16_t gguf_get_val_u16 (const struct gguf_context * ctx, int64_t key_id);
|
||||
GGML_API int16_t gguf_get_val_i16 (const struct gguf_context * ctx, int64_t key_id);
|
||||
GGML_API uint32_t gguf_get_val_u32 (const struct gguf_context * ctx, int64_t key_id);
|
||||
GGML_API int32_t gguf_get_val_i32 (const struct gguf_context * ctx, int64_t key_id);
|
||||
GGML_API float gguf_get_val_f32 (const struct gguf_context * ctx, int64_t key_id);
|
||||
GGML_API uint64_t gguf_get_val_u64 (const struct gguf_context * ctx, int64_t key_id);
|
||||
GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int64_t key_id);
|
||||
GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int64_t key_id);
|
||||
GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int64_t key_id);
|
||||
GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int64_t key_id);
|
||||
GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id);
|
||||
GGML_API size_t gguf_get_arr_n (const struct gguf_context * ctx, int64_t key_id);
|
||||
|
||||
// get raw pointer to the first element of the array with the given key_id
|
||||
// for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference)
|
||||
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id);
|
||||
|
||||
// get ith C string from array with given key_id
|
||||
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);
|
||||
|
||||
GGML_API int64_t gguf_get_n_tensors (const struct gguf_context * ctx);
|
||||
GGML_API int64_t gguf_find_tensor (const struct gguf_context * ctx, const char * name); // returns -1 if the tensor is not found
|
||||
GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int64_t tensor_id);
|
||||
GGML_API const char * gguf_get_tensor_name (const struct gguf_context * ctx, int64_t tensor_id);
|
||||
GGML_API enum ggml_type gguf_get_tensor_type (const struct gguf_context * ctx, int64_t tensor_id);
|
||||
GGML_API size_t gguf_get_tensor_size (const struct gguf_context * ctx, int64_t tensor_id);
|
||||
|
||||
// removes key if it exists, returns id that the key had prior to removal (-1 if it didn't exist)
|
||||
GGML_API int64_t gguf_remove_key(struct gguf_context * ctx, const char * key);
|
||||
|
||||
// overrides an existing KV pair or adds a new one, the new KV pair is always at the back
|
||||
GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val);
|
||||
GGML_API void gguf_set_val_i8 (struct gguf_context * ctx, const char * key, int8_t val);
|
||||
GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val);
|
||||
GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t val);
|
||||
GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val);
|
||||
GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t val);
|
||||
GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float val);
|
||||
GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val);
|
||||
GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t val);
|
||||
GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double val);
|
||||
GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val);
|
||||
GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);
|
||||
|
||||
// creates a new array with n elements of the given type and copies the corresponding number of bytes from data
|
||||
GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, size_t n);
|
||||
|
||||
// creates a new array with n strings and copies the corresponding strings from data
|
||||
GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, size_t n);
|
||||
|
||||
// set or add KV pairs from another context
|
||||
GGML_API void gguf_set_kv(struct gguf_context * ctx, const struct gguf_context * src);
|
||||
|
||||
// add tensor to GGUF context, tensor name must be unique
|
||||
GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor);
|
||||
|
||||
// after changing a tensor's type, the offsets of all tensors with higher indices are immediately recalculated
|
||||
// in such a way that the tensor data remains as one contiguous block (except for padding)
|
||||
GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type);
|
||||
|
||||
// assumes that at least gguf_get_tensor_size bytes can be read from data
|
||||
GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data);
|
||||
|
||||
// writing gguf files can be done in 3 ways:
|
||||
//
|
||||
// - write the entire gguf_context to a binary file in a single pass:
|
||||
//
|
||||
// gguf_write_to_file(ctx, fname, /*only_meta =*/ false);
|
||||
//
|
||||
// - write only the meta data to a file, then re-open the file and append the tensor data:
|
||||
//
|
||||
// gguf_write_to_file(ctx, fname, /*only_meta =*/ true);
|
||||
// FILE * f = fopen(fname, "ab");
|
||||
// fwrite(f, ...); // write tensor data
|
||||
// fclose(f);
|
||||
//
|
||||
// - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:
|
||||
//
|
||||
// FILE * f = fopen(fname, "wb");
|
||||
// const size_t size_meta = gguf_get_meta_size(ctx);
|
||||
// fseek(f, size_meta, SEEK_SET);
|
||||
// fwrite(f, ...); // write tensor data
|
||||
// void * data = malloc(size_meta);
|
||||
// gguf_get_meta_data(ctx, data);
|
||||
// rewind(f);
|
||||
// fwrite(data, 1, data, f);
|
||||
// free(data);
|
||||
// fclose(f);
|
||||
//
|
||||
|
||||
// write the entire context to a binary file
|
||||
GGML_API bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta);
|
||||
|
||||
// get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
|
||||
GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
|
||||
|
||||
// writes the meta data to pointer "data"
|
||||
GGML_API void gguf_get_meta_data(const struct gguf_context * ctx, void * data);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
@ -93,12 +93,18 @@ endif()
|
||||
|
||||
if (GGML_CCACHE)
|
||||
find_program(GGML_CCACHE_FOUND ccache)
|
||||
find_program(GGML_SCCACHE_FOUND sccache)
|
||||
|
||||
if (GGML_CCACHE_FOUND)
|
||||
if (GGML_CCACHE_FOUND OR GGML_SCCACHE_FOUND)
|
||||
if(GGML_CCACHE_FOUND)
|
||||
set(GGML_CCACHE_VARIANT ccache)
|
||||
else()
|
||||
set(GGML_CCACHE_VARIANT sccache)
|
||||
endif()
|
||||
# TODO: should not be set globally
|
||||
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ccache)
|
||||
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${GGML_CCACHE_VARIANT}")
|
||||
set(ENV{CCACHE_SLOPPINESS} time_macros)
|
||||
message(STATUS "ccache found, compilation results will be cached. Disable with GGML_CCACHE=OFF.")
|
||||
message(STATUS "${GGML_CCACHE_VARIANT} found, compilation results will be cached. Disable with GGML_CCACHE=OFF.")
|
||||
else()
|
||||
message(STATUS "Warning: ccache not found - consider installing it for faster compilation or disable this warning with GGML_CCACHE=OFF")
|
||||
endif ()
|
||||
@ -208,6 +214,7 @@ add_library(ggml-base
|
||||
../include/ggml-backend.h
|
||||
../include/ggml-cpp.h
|
||||
../include/ggml-opt.h
|
||||
../include/gguf.h
|
||||
ggml.c
|
||||
ggml-alloc.c
|
||||
ggml-backend.cpp
|
||||
@ -215,7 +222,8 @@ add_library(ggml-base
|
||||
ggml-threading.cpp
|
||||
ggml-threading.h
|
||||
ggml-quants.c
|
||||
ggml-quants.h)
|
||||
ggml-quants.h
|
||||
gguf.cpp)
|
||||
|
||||
target_include_directories(ggml-base PRIVATE .)
|
||||
|
||||
@ -234,6 +242,7 @@ function(ggml_add_backend_library backend)
|
||||
# write the shared library to the output directory
|
||||
set_target_properties(${backend} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
|
||||
target_compile_definitions(${backend} PRIVATE GGML_BACKEND_DL)
|
||||
add_dependencies(ggml ${backend})
|
||||
else()
|
||||
add_library(${backend} ${ARGN})
|
||||
target_link_libraries(ggml PUBLIC ${backend})
|
||||
@ -247,6 +256,17 @@ function(ggml_add_backend_library backend)
|
||||
target_compile_definitions(${backend} PRIVATE GGML_BACKEND_BUILD)
|
||||
target_compile_definitions(${backend} PUBLIC GGML_BACKEND_SHARED)
|
||||
endif()
|
||||
|
||||
if(NOT GGML_AVAILABLE_BACKENDS)
|
||||
set(GGML_AVAILABLE_BACKENDS "${backend}"
|
||||
CACHE INTERNAL "List of backends for cmake package")
|
||||
else()
|
||||
list(FIND GGML_AVAILABLE_BACKENDS "${backend}" has_backend)
|
||||
if(has_backend EQUAL -1)
|
||||
set(GGML_AVAILABLE_BACKENDS "${GGML_AVAILABLE_BACKENDS};${backend}"
|
||||
CACHE INTERNAL "List of backends for cmake package")
|
||||
endif()
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
function(ggml_add_backend backend)
|
||||
@ -289,12 +309,12 @@ if (GGML_CPU_ALL_VARIANTS)
|
||||
ggml_add_cpu_backend_variant(haswell AVX F16C AVX2 FMA)
|
||||
ggml_add_cpu_backend_variant(skylakex AVX F16C AVX2 FMA AVX512)
|
||||
ggml_add_cpu_backend_variant(icelake AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
|
||||
ggml_add_cpu_backend_variant(alderlake AVX F16C AVX2 FMA AVX_VNNI)
|
||||
if (NOT MSVC)
|
||||
# MSVC doesn't support AVX-VNNI or AMX
|
||||
ggml_add_cpu_backend_variant(alderlake AVX F16C AVX2 FMA AVX_VNNI)
|
||||
# MSVC doesn't support AMX
|
||||
ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
|
||||
endif()
|
||||
else ()
|
||||
elseif (GGML_CPU)
|
||||
ggml_add_cpu_backend_variant_impl("")
|
||||
endif()
|
||||
|
||||
|
@ -37,6 +37,7 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml
|
||||
return true;
|
||||
}
|
||||
|
||||
// ops that return true for this function must not use restrict pointers for their backend implementations
|
||||
static bool ggml_op_can_inplace(enum ggml_op op) {
|
||||
switch (op) {
|
||||
case GGML_OP_SCALE:
|
||||
@ -52,8 +53,12 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
|
||||
case GGML_OP_LOG:
|
||||
case GGML_OP_UNARY:
|
||||
case GGML_OP_ROPE:
|
||||
case GGML_OP_ROPE_BACK:
|
||||
case GGML_OP_SILU_BACK:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_SOFT_MAX_BACK:
|
||||
return true;
|
||||
|
||||
default:
|
||||
|
@ -208,7 +208,6 @@ extern "C" {
|
||||
|
||||
// Internal backend registry API
|
||||
GGML_API void ggml_backend_register(ggml_backend_reg_t reg);
|
||||
GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
|
||||
|
||||
// Add backend dynamic loading support to the backend
|
||||
|
||||
|
@ -66,6 +66,26 @@
|
||||
#include "ggml-kompute.h"
|
||||
#endif
|
||||
|
||||
// disable C++17 deprecation warning for std::codecvt_utf8
|
||||
#if defined(__clang__)
|
||||
# pragma clang diagnostic push
|
||||
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
|
||||
#endif
|
||||
|
||||
static std::wstring utf8_to_utf16(const std::string & str) {
|
||||
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
|
||||
return converter.from_bytes(str);
|
||||
}
|
||||
|
||||
static std::string utf16_to_utf8(const std::wstring & str) {
|
||||
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
|
||||
return converter.to_bytes(str);
|
||||
}
|
||||
|
||||
#if defined(__clang__)
|
||||
# pragma clang diagnostic pop
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
|
||||
using dl_handle = std::remove_pointer_t<HMODULE>;
|
||||
@ -88,11 +108,6 @@ static dl_handle * dl_load_library(const std::wstring & path) {
|
||||
return handle;
|
||||
}
|
||||
|
||||
static dl_handle * dl_load_library(const std::string & path) {
|
||||
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
|
||||
return dl_load_library(converter.from_bytes(path));
|
||||
}
|
||||
|
||||
static void * dl_get_sym(dl_handle * handle, const char * name) {
|
||||
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
|
||||
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
|
||||
@ -114,8 +129,8 @@ struct dl_handle_deleter {
|
||||
}
|
||||
};
|
||||
|
||||
static void * dl_load_library(const std::string & path) {
|
||||
dl_handle * handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL);
|
||||
static void * dl_load_library(const std::wstring & path) {
|
||||
dl_handle * handle = dlopen(utf16_to_utf8(path).c_str(), RTLD_NOW | RTLD_LOCAL);
|
||||
|
||||
return handle;
|
||||
}
|
||||
@ -202,11 +217,11 @@ struct ggml_backend_registry {
|
||||
devices.push_back(device);
|
||||
}
|
||||
|
||||
ggml_backend_reg_t load_backend(const char * path, bool silent) {
|
||||
ggml_backend_reg_t load_backend(const std::wstring & path, bool silent) {
|
||||
dl_handle_ptr handle { dl_load_library(path) };
|
||||
if (!handle) {
|
||||
if (!silent) {
|
||||
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path);
|
||||
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(path).c_str());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
@ -214,7 +229,7 @@ struct ggml_backend_registry {
|
||||
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
||||
if (score_fn && score_fn() == 0) {
|
||||
if (!silent) {
|
||||
GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, path);
|
||||
GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, utf16_to_utf8(path).c_str());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
@ -222,7 +237,7 @@ struct ggml_backend_registry {
|
||||
auto backend_init_fn = (ggml_backend_init_t) dl_get_sym(handle.get(), "ggml_backend_init");
|
||||
if (!backend_init_fn) {
|
||||
if (!silent) {
|
||||
GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, path);
|
||||
GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, utf16_to_utf8(path).c_str());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
@ -231,16 +246,16 @@ struct ggml_backend_registry {
|
||||
if (!reg || reg->api_version != GGML_BACKEND_API_VERSION) {
|
||||
if (!silent) {
|
||||
if (!reg) {
|
||||
GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, path);
|
||||
GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, utf16_to_utf8(path).c_str());
|
||||
} else {
|
||||
GGML_LOG_ERROR("%s: failed to initialize backend from %s: incompatible API version (backend: %d, current: %d)\n",
|
||||
__func__, path, reg->api_version, GGML_BACKEND_API_VERSION);
|
||||
__func__, utf16_to_utf8(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path);
|
||||
GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), utf16_to_utf8(path).c_str());
|
||||
|
||||
register_backend(reg, std::move(handle));
|
||||
|
||||
@ -376,14 +391,14 @@ ggml_backend_t ggml_backend_init_best(void) {
|
||||
|
||||
// Dynamic loading
|
||||
ggml_backend_reg_t ggml_backend_load(const char * path) {
|
||||
return get_reg().load_backend(path, false);
|
||||
return get_reg().load_backend(utf8_to_utf16(path), false);
|
||||
}
|
||||
|
||||
void ggml_backend_unload(ggml_backend_reg_t reg) {
|
||||
get_reg().unload_backend(reg, true);
|
||||
}
|
||||
|
||||
static std::string get_executable_path() {
|
||||
static std::wstring get_executable_path() {
|
||||
#if defined(__APPLE__)
|
||||
// get executable path
|
||||
std::vector<char> path;
|
||||
@ -401,13 +416,17 @@ static std::string get_executable_path() {
|
||||
if (last_slash != std::string::npos) {
|
||||
base_path = base_path.substr(0, last_slash);
|
||||
}
|
||||
return base_path + "/";
|
||||
#elif defined(__linux__)
|
||||
return utf8_to_utf16(base_path + "/");
|
||||
#elif defined(__linux__) || defined(__FreeBSD__)
|
||||
std::string base_path = ".";
|
||||
std::vector<char> path(1024);
|
||||
while (true) {
|
||||
// get executable path
|
||||
# if defined(__linux__)
|
||||
ssize_t len = readlink("/proc/self/exe", path.data(), path.size());
|
||||
# elif defined(__FreeBSD__)
|
||||
ssize_t len = readlink("/proc/curproc/file", path.data(), path.size());
|
||||
# endif
|
||||
if (len == -1) {
|
||||
break;
|
||||
}
|
||||
@ -423,57 +442,63 @@ static std::string get_executable_path() {
|
||||
path.resize(path.size() * 2);
|
||||
}
|
||||
|
||||
return base_path + "/";
|
||||
return utf8_to_utf16(base_path + "/");
|
||||
#elif defined(_WIN32)
|
||||
std::vector<char> path(MAX_PATH);
|
||||
DWORD len = GetModuleFileNameA(NULL, path.data(), path.size());
|
||||
std::vector<wchar_t> path(MAX_PATH);
|
||||
DWORD len = GetModuleFileNameW(NULL, path.data(), path.size());
|
||||
if (len == 0) {
|
||||
return "";
|
||||
return {};
|
||||
}
|
||||
std::string base_path(path.data(), len);
|
||||
std::wstring base_path(path.data(), len);
|
||||
// remove executable name
|
||||
auto last_slash = base_path.find_last_of('\\');
|
||||
if (last_slash != std::string::npos) {
|
||||
base_path = base_path.substr(0, last_slash);
|
||||
}
|
||||
return base_path + "\\";
|
||||
return base_path + L"\\";
|
||||
#else
|
||||
return {};
|
||||
#endif
|
||||
}
|
||||
|
||||
static std::string backend_filename_prefix() {
|
||||
static std::wstring backend_filename_prefix() {
|
||||
#ifdef _WIN32
|
||||
return "ggml-";
|
||||
return L"ggml-";
|
||||
#else
|
||||
return "libggml-";
|
||||
return L"libggml-";
|
||||
#endif
|
||||
}
|
||||
|
||||
static std::string backend_filename_suffix() {
|
||||
static std::wstring backend_filename_suffix() {
|
||||
#ifdef _WIN32
|
||||
return ".dll";
|
||||
return L".dll";
|
||||
#else
|
||||
return ".so";
|
||||
return L".so";
|
||||
#endif
|
||||
}
|
||||
|
||||
static std::wstring path_separator() {
|
||||
#ifdef _WIN32
|
||||
return L"\\";
|
||||
#else
|
||||
return L"/";
|
||||
#endif
|
||||
}
|
||||
|
||||
static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, const char * user_search_path) {
|
||||
// enumerate all the files that match [lib]ggml-name-*.[so|dll] in the search paths
|
||||
// TODO: search system paths
|
||||
std::string file_prefix = backend_filename_prefix() + name + "-";
|
||||
std::vector<std::string> search_paths;
|
||||
std::wstring file_prefix = backend_filename_prefix() + utf8_to_utf16(name) + L"-";
|
||||
std::vector<std::wstring> search_paths;
|
||||
if (user_search_path == nullptr) {
|
||||
search_paths.push_back("./");
|
||||
search_paths.push_back(L"." + path_separator());
|
||||
search_paths.push_back(get_executable_path());
|
||||
} else {
|
||||
#if defined(_WIN32)
|
||||
search_paths.push_back(std::string(user_search_path) + "\\");
|
||||
#else
|
||||
search_paths.push_back(std::string(user_search_path) + "/");
|
||||
#endif
|
||||
search_paths.push_back(utf8_to_utf16(user_search_path) + path_separator());
|
||||
}
|
||||
|
||||
int best_score = 0;
|
||||
std::string best_path;
|
||||
std::wstring best_path;
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
for (const auto & search_path : search_paths) {
|
||||
@ -483,27 +508,27 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
|
||||
for (const auto & entry : dir_it) {
|
||||
if (entry.is_regular_file()) {
|
||||
std::string filename = entry.path().filename().string();
|
||||
std::string ext = entry.path().extension().string();
|
||||
std::wstring filename = entry.path().filename().wstring();
|
||||
std::wstring ext = entry.path().extension().wstring();
|
||||
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
|
||||
dl_handle_ptr handle { dl_load_library(entry.path().c_str()) };
|
||||
dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
|
||||
if (!handle && !silent) {
|
||||
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, entry.path().string().c_str());
|
||||
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||
}
|
||||
if (handle) {
|
||||
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
||||
if (score_fn) {
|
||||
int s = score_fn();
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, entry.path().string().c_str(), s);
|
||||
GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
|
||||
#endif
|
||||
if (s > best_score) {
|
||||
best_score = s;
|
||||
best_path = entry.path().string();
|
||||
best_path = entry.path().wstring();
|
||||
}
|
||||
} else {
|
||||
if (!silent) {
|
||||
GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, entry.path().string().c_str());
|
||||
GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -515,15 +540,15 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||
if (best_score == 0) {
|
||||
// try to load the base backend
|
||||
for (const auto & search_path : search_paths) {
|
||||
std::string path = search_path + backend_filename_prefix() + name + backend_filename_suffix();
|
||||
std::wstring path = search_path + backend_filename_prefix() + utf8_to_utf16(name) + backend_filename_suffix();
|
||||
if (fs::exists(path)) {
|
||||
return get_reg().load_backend(path.c_str(), silent);
|
||||
return get_reg().load_backend(path, silent);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return get_reg().load_backend(best_path.c_str(), silent);
|
||||
return get_reg().load_backend(best_path, silent);
|
||||
}
|
||||
|
||||
void ggml_backend_load_all() {
|
||||
@ -549,4 +574,9 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
|
||||
ggml_backend_load_best("opencl", silent, dir_path);
|
||||
ggml_backend_load_best("musa", silent, dir_path);
|
||||
ggml_backend_load_best("cpu", silent, dir_path);
|
||||
// check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
|
||||
const char * backend_path = std::getenv("GGML_BACKEND_PATH");
|
||||
if (backend_path) {
|
||||
ggml_backend_load(backend_path);
|
||||
}
|
||||
}
|
||||
|
@ -764,7 +764,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
|
||||
if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
|
||||
int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor);
|
||||
// check if a backend with higher prio wants to offload the op
|
||||
if (src_backend_id == sched->n_backends - 1) {
|
||||
if (src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) {
|
||||
for (int b = 0; b < src_backend_id; b++) {
|
||||
if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) {
|
||||
SET_CAUSE(tensor, "1.off");
|
||||
@ -795,9 +795,12 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
|
||||
for (int i = 0; i < graph->n_nodes; i++) {
|
||||
if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) {
|
||||
ggml_backend_t split_backend = sched->backends[sched->splits[cur_split].backend_id];
|
||||
GGML_LOG_DEBUG("\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend),
|
||||
GGML_LOG_DEBUG("\n## SPLIT #%d: %s # %d inputs", cur_split, ggml_backend_name(split_backend),
|
||||
sched->splits[cur_split].n_inputs);
|
||||
for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) {
|
||||
if (j == 0) {
|
||||
GGML_LOG_DEBUG(": ");
|
||||
}
|
||||
GGML_LOG_DEBUG("[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name,
|
||||
fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j])));
|
||||
}
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user