Mirror of https://github.com/ggerganov/whisper.cpp.git (synced 2025-08-11 00:56:54 +02:00)

Compare commits: 155 commits, v1.7.5 ... sync-ggml-
Author | SHA1 | Date | |
---|---|---|---|
0055356fbc | |||
eeaa1cd035 | |||
a652c8bf72 | |||
0630539c8a | |||
a7988d76db | |||
37ac0264ef | |||
5a9ccde7da | |||
cde0e50536 | |||
df458380d6 | |||
87b88ed01c | |||
9b584b0cc0 | |||
09846f4e12 | |||
bcf1ed0163 | |||
934d4b3083 | |||
988dcd4b5b | |||
9f540ad8cb | |||
1fa17bc752 | |||
366082d072 | |||
0778b6ff5f | |||
5cd59c9396 | |||
d052e64d42 | |||
780750a108 | |||
919c78e618 | |||
dc288f84cd | |||
1543a3600c | |||
4872355f6e | |||
1a76e97c28 | |||
7017c1d37d | |||
670bf02662 | |||
9fff2f751c | |||
46392f733f | |||
eeb259909e | |||
fe21ddf0dc | |||
33bdbfbb33 | |||
0f49edf0f3 | |||
25efcfe3ed | |||
edbd4cb7f5 | |||
3ae9b8416a | |||
55d73a13f5 | |||
2e30e6df59 | |||
f0171f0616 | |||
b7db9e7aac | |||
f3c42399a3 | |||
28dcdff4c5 | |||
50218b935d | |||
f9b2dfdd8c | |||
50fda73f4c | |||
1c20f46887 | |||
adaea088bc | |||
6c0d843f9d | |||
efb800557f | |||
337becefb9 | |||
11ae30c19e | |||
88c3cecd43 | |||
fe4acb33e3 | |||
fd5a3e1bc6 | |||
01e1600edd | |||
cf3eb291ab | |||
3d54b68ea7 | |||
11218294db | |||
33c89ade7d | |||
27a56e7243 | |||
f4ca3e2f9c | |||
0287a5c51b | |||
24d29c55df | |||
36019c35a3 | |||
4e936e2afa | |||
314ce5981e | |||
cb7642b0f5 | |||
7db8f278f0 | |||
be42a19eab | |||
b8755670ca | |||
483eecae62 | |||
43e3d25d93 | |||
e1dbf9a42e | |||
ee0013865d | |||
32a407166b | |||
622f981853 | |||
d049d67065 | |||
877308838e | |||
d87dfcf7c0 | |||
915c14ef10 | |||
5d33d3c929 | |||
751e42b21e | |||
e8ee32d12d | |||
e9ce285135 | |||
b942f451b6 | |||
e6410faf99 | |||
182df69384 | |||
3bf9691dfd | |||
ba444e9c23 | |||
c6caf8eef2 | |||
6cae79a1d7 | |||
b9bfe0c693 | |||
1d50c6ac22 | |||
79f23d9132 | |||
ee2cbeeb74 | |||
868a5ce310 | |||
b9c71fae5a | |||
6d67c6d93d | |||
12cade118e | |||
fd1c725e65 | |||
d33fd00cfe | |||
3e0d89782a | |||
7074b622eb | |||
b8d3e45342 | |||
1901505138 | |||
3c26dd3353 | |||
d792d2a2dc | |||
8add58aa5e | |||
8f8ede1b12 | |||
3a6fe8d767 | |||
76231bda56 | |||
785437c253 | |||
2f0612cb1c | |||
e944065d5b | |||
ccc7b5df0b | |||
fbed36851e | |||
d1d847f184 | |||
337f91d4a6 | |||
317a0031f9 | |||
b243416918 | |||
6e532c7187 | |||
2105b110d3 | |||
f82622180f | |||
a71c64512a | |||
1e9c2f87f1 | |||
06ce8f83e6 | |||
8b92060a10 | |||
7858eddd10 | |||
3a88f1e504 | |||
f0d2bfbfb7 | |||
170b2faf75 | |||
f8a3509b6d | |||
2a2d21c75d | |||
9cfcd6cc45 | |||
e853620270 | |||
549db9376f | |||
33a25e4dda | |||
43f5030aeb | |||
cf794133de | |||
ef6cf357e7 | |||
b1f5c11b32 | |||
ada745f4a5 | |||
01985c22c0 | |||
448f3d3b93 | |||
e6234cd435 | |||
2b6d0d2200 | |||
0b17d4507e | |||
77e0c86ab6 | |||
eac1bc9c47 | |||
cbde66d913 | |||
513ecf8dc0 | |||
cce5daf17b | |||
2c502b3c00 |
@ -13,8 +13,6 @@ WORKDIR /app
|
|||||||
ARG CUDA_DOCKER_ARCH=all
|
ARG CUDA_DOCKER_ARCH=all
|
||||||
# Set nvcc architecture
|
# Set nvcc architecture
|
||||||
ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
|
ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
|
||||||
# Enable cuBLAS
|
|
||||||
ENV GGML_CUDA=1
|
|
||||||
|
|
||||||
RUN apt-get update && \
|
RUN apt-get update && \
|
||||||
apt-get install -y build-essential libsdl2-dev wget cmake git \
|
apt-get install -y build-essential libsdl2-dev wget cmake git \
|
||||||
@ -25,7 +23,8 @@ ENV CUDA_MAIN_VERSION=12.3
|
|||||||
ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH
|
ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH
|
||||||
|
|
||||||
COPY .. .
|
COPY .. .
|
||||||
RUN make base.en
|
# Enable cuBLAS
|
||||||
|
RUN make base.en CMAKE_ARGS="-DGGML_CUDA=1"
|
||||||
|
|
||||||
FROM ${BASE_CUDA_RUN_CONTAINER} AS runtime
|
FROM ${BASE_CUDA_RUN_CONTAINER} AS runtime
|
||||||
ENV CUDA_MAIN_VERSION=12.3
|
ENV CUDA_MAIN_VERSION=12.3
|
||||||
@ -37,4 +36,5 @@ RUN apt-get update && \
|
|||||||
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
|
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
|
||||||
|
|
||||||
COPY --from=build /app /app
|
COPY --from=build /app /app
|
||||||
|
ENV PATH=/app/build/bin:$PATH
|
||||||
ENTRYPOINT [ "bash", "-c" ]
|
ENTRYPOINT [ "bash", "-c" ]
|
||||||
|
29
.devops/main-musa.Dockerfile
Normal file
29
.devops/main-musa.Dockerfile
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
ARG UBUNTU_VERSION=22.04
|
||||||
|
# This needs to generally match the container host's environment.
|
||||||
|
ARG MUSA_VERSION=rc3.1.1
|
||||||
|
# Target the MUSA build image
|
||||||
|
ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
|
||||||
|
# Target the MUSA runtime image
|
||||||
|
ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}
|
||||||
|
|
||||||
|
FROM ${BASE_MUSA_DEV_CONTAINER} AS build
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
RUN apt-get update && \
|
||||||
|
apt-get install -y build-essential libsdl2-dev wget cmake git \
|
||||||
|
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
|
||||||
|
|
||||||
|
COPY .. .
|
||||||
|
# Enable muBLAS
|
||||||
|
RUN make base.en CMAKE_ARGS="-DGGML_MUSA=1"
|
||||||
|
|
||||||
|
FROM ${BASE_MUSA_RUN_CONTAINER} AS runtime
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
RUN apt-get update && \
|
||||||
|
apt-get install -y curl ffmpeg wget cmake git \
|
||||||
|
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
|
||||||
|
|
||||||
|
COPY --from=build /app /app
|
||||||
|
ENV PATH=/app/build/bin:$PATH
|
||||||
|
ENTRYPOINT [ "bash", "-c" ]
|
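For local testing, this new image can be built and exercised much like the existing `main` image. The commands below are a minimal sketch: the local tag is illustrative, and any device flags needed to expose the Moore Threads GPU inside the container are omitted.

```bash
# Build the MUSA image from the repository root
docker build -f .devops/main-musa.Dockerfile -t whisper.cpp:main-musa .

# Transcribe a sample, mounting models and audio from the host
docker run -it --rm \
  -v path/to/models:/models \
  -v path/to/audios:/audios \
  whisper.cpp:main-musa "whisper-cli -m /models/ggml-base.bin -f /audios/jfk.wav"
```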
@ -16,4 +16,5 @@ RUN apt-get update && \
|
|||||||
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
|
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
|
||||||
|
|
||||||
COPY --from=build /app /app
|
COPY --from=build /app /app
|
||||||
|
ENV PATH=/app/build/bin:$PATH
|
||||||
ENTRYPOINT [ "bash", "-c" ]
|
ENTRYPOINT [ "bash", "-c" ]
|
||||||
|
3
.dockerignore
Normal file
3
.dockerignore
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
build*/
|
||||||
|
.github/
|
||||||
|
.devops/
|
54
.github/workflows/bindings-ruby.yml
vendored
54
.github/workflows/bindings-ruby.yml
vendored
@ -1,55 +1,11 @@
|
|||||||
name: Bindings Tests (Ruby)
|
name: Bindings Tests (Ruby)
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
paths:
|
branches:
|
||||||
- bindings/ruby/**
|
- master
|
||||||
- src/**/*.c
|
|
||||||
- src/**/*.cpp
|
|
||||||
- src/**/*.h
|
|
||||||
- src/**/*.m
|
|
||||||
- src/**/*.metal
|
|
||||||
- include/**/*.c
|
|
||||||
- include/**/*.cpp
|
|
||||||
- include/**/*.h
|
|
||||||
- include/**/*.m
|
|
||||||
- include/**/*.metal
|
|
||||||
- ggml/**/*.c
|
|
||||||
- ggml/**/*.cpp
|
|
||||||
- ggml/**/*.h
|
|
||||||
- ggml/**/*.m
|
|
||||||
- ggml/**/*.metal
|
|
||||||
- scripts/get-flags.mk
|
|
||||||
- examples/common.h
|
|
||||||
- examples/common.cpp
|
|
||||||
- examples/common-whisper.h
|
|
||||||
- examples/common-whisper.cpp
|
|
||||||
- examples/stb_vorbis.c
|
|
||||||
- examples/miniaudio.h
|
|
||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
types: [opened, synchronize, reopened]
|
||||||
- bindings/ruby/**
|
|
||||||
- src/**/*.c
|
|
||||||
- src/**/*.cpp
|
|
||||||
- src/**/*.h
|
|
||||||
- src/**/*.m
|
|
||||||
- src/**/*.metal
|
|
||||||
- include/**/*.c
|
|
||||||
- include/**/*.cpp
|
|
||||||
- include/**/*.h
|
|
||||||
- include/**/*.m
|
|
||||||
- include/**/*.metal
|
|
||||||
- ggml/**/*.c
|
|
||||||
- ggml/**/*.cpp
|
|
||||||
- ggml/**/*.h
|
|
||||||
- ggml/**/*.m
|
|
||||||
- ggml/**/*.metal
|
|
||||||
- scripts/get-flags.mk
|
|
||||||
- examples/common.h
|
|
||||||
- examples/common.cpp
|
|
||||||
- examples/common-whisper.h
|
|
||||||
- examples/common-whisper.cpp
|
|
||||||
- examples/stb_vorbis.c
|
|
||||||
- examples/miniaudio.h
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
ubuntu-22:
|
ubuntu-22:
|
||||||
@ -60,6 +16,6 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- uses: ruby/setup-ruby@v1
|
- uses: ruby/setup-ruby@v1
|
||||||
with:
|
with:
|
||||||
ruby-version: '3.1'
|
ruby-version: '3.2'
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- run: rake test
|
- run: rake test
|
||||||
|
180
.github/workflows/build.yml
vendored
180
.github/workflows/build.yml
vendored
@ -200,23 +200,23 @@ jobs:
|
|||||||
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)
|
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)
|
||||||
|
|
||||||
|
|
||||||
freeBSD-latest:
|
# freeBSD-latest:
|
||||||
runs-on: macos-13
|
# runs-on: macos-13
|
||||||
|
#
|
||||||
steps:
|
# steps:
|
||||||
- name: Clone
|
# - name: Clone
|
||||||
uses: actions/checkout@v4
|
# uses: actions/checkout@v4
|
||||||
|
#
|
||||||
- name: Build
|
# - name: Build
|
||||||
uses: cross-platform-actions/action@v0.27.0
|
# uses: cross-platform-actions/action@v0.27.0
|
||||||
with:
|
# with:
|
||||||
operating_system: freebsd
|
# operating_system: freebsd
|
||||||
version: '14.2'
|
# version: '14.2'
|
||||||
run: |
|
# run: |
|
||||||
sudo pkg update
|
# sudo pkg update
|
||||||
sudo pkg install -y gmake sdl2 cmake git
|
# sudo pkg install -y gmake sdl2 cmake git
|
||||||
cmake -B build
|
# cmake -B build
|
||||||
cmake --build build --config Release
|
# cmake --build build --config Release
|
||||||
|
|
||||||
ubuntu-22-gcc:
|
ubuntu-22-gcc:
|
||||||
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
|
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
|
||||||
@ -561,6 +561,7 @@ jobs:
|
|||||||
run: >
|
run: >
|
||||||
cmake -S . -B ./build -A ${{ matrix.arch }}
|
cmake -S . -B ./build -A ${{ matrix.arch }}
|
||||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||||
|
-DBUILD_SHARED_LIBS=ON
|
||||||
-DWHISPER_SDL2=${{ matrix.sdl2 }}
|
-DWHISPER_SDL2=${{ matrix.sdl2 }}
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
@ -572,12 +573,37 @@ jobs:
|
|||||||
if: matrix.sdl2 == 'ON'
|
if: matrix.sdl2 == 'ON'
|
||||||
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
|
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
|
||||||
|
|
||||||
- name: Upload dll
|
- name: Upload SDL2.dll
|
||||||
|
if: matrix.sdl2 == 'ON'
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: ${{ matrix.jnaPath }}_whisper.dll
|
name: ${{ matrix.s2arc }}_SDL2.dll
|
||||||
|
path: build/bin/${{ matrix.build }}/SDL2.dll
|
||||||
|
|
||||||
|
- name: Upload whisper dll
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: whisper_${{ matrix.arch }}.dll
|
||||||
path: build/bin/${{ matrix.build }}/whisper.dll
|
path: build/bin/${{ matrix.build }}/whisper.dll
|
||||||
|
|
||||||
|
- name: Upload ggml dll
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: ggml_${{ matrix.arch }}.dll
|
||||||
|
path: build/bin/${{ matrix.build }}/ggml.dll
|
||||||
|
|
||||||
|
- name: Upload ggml base dll
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: ggml_base_${{ matrix.arch }}.dll
|
||||||
|
path: build/bin/${{ matrix.build }}/ggml-base.dll
|
||||||
|
|
||||||
|
- name: Upload ggml cpu dll
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: ggml_cpu_${{ matrix.arch }}.dll
|
||||||
|
path: build/bin/${{ matrix.build }}/ggml-cpu.dll
|
||||||
|
|
||||||
- name: Upload binaries
|
- name: Upload binaries
|
||||||
if: matrix.sdl2 == 'ON'
|
if: matrix.sdl2 == 'ON'
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
@ -938,7 +964,7 @@ jobs:
|
|||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
path: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip
|
path: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip
|
||||||
name: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework
|
name: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip
|
||||||
|
|
||||||
android:
|
android:
|
||||||
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
|
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
|
||||||
@ -996,38 +1022,88 @@ jobs:
|
|||||||
chmod +x ./gradlew
|
chmod +x ./gradlew
|
||||||
./gradlew assembleRelease
|
./gradlew assembleRelease
|
||||||
|
|
||||||
# TODO: disabled because of following fail: https://github.com/ggerganov/whisper.cpp/actions/runs/9686220096/job/26735899598
|
bindings-java:
|
||||||
# java:
|
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
|
||||||
# needs: [ 'windows' ]
|
github.event.inputs.run_type == 'full-ci' }}
|
||||||
# runs-on: windows-latest
|
needs: ['windows']
|
||||||
# steps:
|
runs-on: windows-latest
|
||||||
# - uses: actions/checkout@v4
|
steps:
|
||||||
#
|
- uses: actions/checkout@v4
|
||||||
# - name: Install Java
|
|
||||||
# uses: actions/setup-java@v4
|
- name: Install Java
|
||||||
# with:
|
uses: actions/setup-java@v4
|
||||||
# distribution: zulu
|
with:
|
||||||
# java-version: 20
|
distribution: zulu
|
||||||
#
|
java-version: 20
|
||||||
# - name: Download Windows lib
|
|
||||||
# uses: actions/download-artifact@v4
|
- name: Download Whisper Windows lib
|
||||||
# with:
|
uses: actions/download-artifact@v4
|
||||||
# name: win32-x86-64_whisper.dll
|
with:
|
||||||
# path: bindings/java/build/generated/resources/main/win32-x86-64
|
name: whisper_x64.dll
|
||||||
#
|
|
||||||
# - name: Build
|
- name: Download GGML Windows lib
|
||||||
# run: |
|
uses: actions/download-artifact@v4
|
||||||
# models\download-ggml-model.cmd tiny.en
|
with:
|
||||||
# cd bindings/java
|
name: ggml_x64.dll
|
||||||
# chmod +x ./gradlew
|
|
||||||
# ./gradlew build
|
- name: Download GGML Base Windows lib
|
||||||
#
|
uses: actions/download-artifact@v4
|
||||||
# - name: Upload jar
|
with:
|
||||||
# uses: actions/upload-artifact@v4
|
name: ggml_base_x64.dll
|
||||||
# with:
|
|
||||||
# name: whispercpp.jar
|
- name: Download GGML CPU Windows lib
|
||||||
# path: bindings/java/build/libs/whispercpp-*.jar
|
uses: actions/download-artifact@v4
|
||||||
#
|
with:
|
||||||
|
name: ggml_cpu_x64.dll
|
||||||
|
|
||||||
|
- name: Download SDL2.dll
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
name: x64_SDL2.dll
|
||||||
|
|
||||||
|
- name: List downloaded files
|
||||||
|
shell: pwsh
|
||||||
|
run: |
|
||||||
|
Get-ChildItem -Path "." -Recurse -Filter "*.dll"
|
||||||
|
|
||||||
|
- name: Move DLL to correct location
|
||||||
|
shell: pwsh
|
||||||
|
run: |
|
||||||
|
New-Item -Path "build\bin\Release" -ItemType Directory -Force
|
||||||
|
|
||||||
|
Copy-Item -Path "whisper.dll" -Destination "build\bin\Release\whisper.dll" -Force
|
||||||
|
Write-Host "Copied whisper.dll to build\bin\Release\whisper.dll directory"
|
||||||
|
|
||||||
|
Copy-Item -Path "ggml.dll" -Destination "build\bin\Release\ggml.dll" -Force
|
||||||
|
Write-Host "Copied ggml.dll to build\bin\Release\ggml.dll directory"
|
||||||
|
|
||||||
|
Copy-Item -Path "ggml-base.dll" -Destination "build\bin\Release\ggml-base.dll" -Force
|
||||||
|
Write-Host "Copied ggml-base.dll to build\bin\Release\ggml-base.dll directory"
|
||||||
|
|
||||||
|
Copy-Item -Path "ggml-cpu.dll" -Destination "build\bin\Release\ggml-cpu.dll" -Force
|
||||||
|
Write-Host "Copied ggml-cpu.dll to build\bin\Release\ggml-cpu.dll directory"
|
||||||
|
|
||||||
|
Copy-Item -Path "SDL2.dll" -Destination "build\bin\Release\SDL2.dll" -Force
|
||||||
|
Write-Host "Copied SDL2.dll to build\bin\Release\SDL2.dll directory"
|
||||||
|
|
||||||
|
- name: List build release files
|
||||||
|
shell: pwsh
|
||||||
|
run: |
|
||||||
|
Get-ChildItem -Path "build\Release" -Recurse -Filter "*.dll"
|
||||||
|
|
||||||
|
- name: Build
|
||||||
|
run: |
|
||||||
|
models\download-ggml-model.cmd tiny.en models/
|
||||||
|
cd bindings/java
|
||||||
|
chmod +x ./gradlew
|
||||||
|
./gradlew build --info
|
||||||
|
|
||||||
|
- name: Upload jar
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: whispercpp.jar
|
||||||
|
path: bindings/java/build/libs/whispercpp-*.jar
|
||||||
|
|
||||||
# - name: Publish package
|
# - name: Publish package
|
||||||
# if: ${{ github.ref == 'refs/heads/master' }}
|
# if: ${{ github.ref == 'refs/heads/master' }}
|
||||||
# uses: gradle/gradle-build-action@v2.4.2
|
# uses: gradle/gradle-build-action@v2.4.2
|
||||||
|
1
.github/workflows/docker.yml
vendored
1
.github/workflows/docker.yml
vendored
@ -18,6 +18,7 @@ jobs:
|
|||||||
matrix:
|
matrix:
|
||||||
config:
|
config:
|
||||||
- { tag: "main", dockerfile: ".devops/main.Dockerfile", platform: "linux/amd64" }
|
- { tag: "main", dockerfile: ".devops/main.Dockerfile", platform: "linux/amd64" }
|
||||||
|
- { tag: "main-musa", dockerfile: ".devops/main-musa.Dockerfile", platform: "linux/amd64" }
|
||||||
#TODO: the cuda image keeps failing - disable for now
|
#TODO: the cuda image keeps failing - disable for now
|
||||||
# https://github.com/ggerganov/whisper.cpp/actions/runs/11019444428/job/30602020339
|
# https://github.com/ggerganov/whisper.cpp/actions/runs/11019444428/job/30602020339
|
||||||
#- { tag: "main-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platform: "linux/amd64" }
|
#- { tag: "main-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platform: "linux/amd64" }
|
||||||
|
0
.gitmodules
vendored
0
.gitmodules
vendored
@ -135,6 +135,22 @@ if (NOT TARGET ggml)
|
|||||||
add_library(ggml ALIAS ggml::ggml)
|
add_library(ggml ALIAS ggml::ggml)
|
||||||
else()
|
else()
|
||||||
add_subdirectory(ggml)
|
add_subdirectory(ggml)
|
||||||
|
if(WIN32)
|
||||||
|
# The following adds a _DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR macro and is a workaround for
|
||||||
|
# the Windows C++ standard library which does not support constexpr mutexes.
|
||||||
|
# From the release notes: https://github.com/microsoft/STL/wiki/Changelog
|
||||||
|
# Disable constexpr mutex constructor on Windows
|
||||||
|
# Fixed mutex's constructor to be constexpr. #3824 #4000 #4339
|
||||||
|
# Note: Programs that aren't following the documented restrictions on binary compatibility may encounter
|
||||||
|
# null dereferences in mutex machinery. You must follow this rule:
|
||||||
|
# When you mix binaries built by different supported versions of the toolset, the Redistributable version
|
||||||
|
# must be at least as new as the latest toolset used by any app component.
|
||||||
|
# You can define _DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR as an escape hatch.
|
||||||
|
#
|
||||||
|
# Specifically for whisper.cpp, this would cause a crash when using the Java bindings,
|
||||||
|
# resulting in an Invalid memory access error.
|
||||||
|
target_compile_definitions(ggml-base PRIVATE _DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR)
|
||||||
|
endif()
|
||||||
endif()
|
endif()
|
||||||
# ... otherwise assume ggml is added by a parent CMakeLists.txt
|
# ... otherwise assume ggml is added by a parent CMakeLists.txt
|
||||||
endif()
|
endif()
|
||||||
@ -197,3 +213,36 @@ endif ()
|
|||||||
if (WHISPER_BUILD_EXAMPLES)
|
if (WHISPER_BUILD_EXAMPLES)
|
||||||
add_subdirectory(examples)
|
add_subdirectory(examples)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if (MSVC)
|
||||||
|
set(MSVC_WARNING_FLAGS
|
||||||
|
/wd4101 # Unreferenced local variable
|
||||||
|
/wd4005 # Macro redefinition
|
||||||
|
/wd4065 # switch statement contains 'default' but no 'case' labels
|
||||||
|
/wd4267 # Conversion from 'size_t' to a smaller type, possible loss of data
|
||||||
|
/wd4244 # Conversion from one type to another type, possible loss of data
|
||||||
|
/wd4805 # Unsafe mix of type
|
||||||
|
/wd4305 # Truncation from 'type1' to 'type2' (often double to float)
|
||||||
|
/wd4996 # Function or variable may be unsafe/deprecated
|
||||||
|
)
|
||||||
|
function(disable_msvc_warnings target_name)
|
||||||
|
if(TARGET ${target_name})
|
||||||
|
target_compile_options(${target_name} PRIVATE ${MSVC_WARNING_FLAGS})
|
||||||
|
endif()
|
||||||
|
endfunction()
|
||||||
|
|
||||||
|
if (WHISPER_BUILD_EXAMPLES)
|
||||||
|
disable_msvc_warnings(whisper)
|
||||||
|
disable_msvc_warnings(common)
|
||||||
|
disable_msvc_warnings(common-sdl)
|
||||||
|
disable_msvc_warnings(lsp)
|
||||||
|
disable_msvc_warnings(wchess-core)
|
||||||
|
disable_msvc_warnings(whisper-command)
|
||||||
|
disable_msvc_warnings(whisper-cli)
|
||||||
|
disable_msvc_warnings(whisper-server)
|
||||||
|
disable_msvc_warnings(whisper-stream)
|
||||||
|
disable_msvc_warnings(whisper-talk-llama)
|
||||||
|
disable_msvc_warnings(whisper-bench)
|
||||||
|
disable_msvc_warnings(quantize)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
8
Makefile
8
Makefile
@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
.PHONY: build
|
.PHONY: build
|
||||||
build:
|
build:
|
||||||
cmake -B build
|
cmake -B build $(CMAKE_ARGS)
|
||||||
cmake --build build --config Release
|
cmake --build build --config Release
|
||||||
|
|
||||||
# download a few audio samples into folder "./samples":
|
# download a few audio samples into folder "./samples":
|
||||||
@ -41,17 +41,17 @@ samples:
|
|||||||
|
|
||||||
tiny.en tiny base.en base small.en small medium.en medium large-v1 large-v2 large-v3 large-v3-turbo:
|
tiny.en tiny base.en base small.en small medium.en medium large-v1 large-v2 large-v3 large-v3-turbo:
|
||||||
bash ./models/download-ggml-model.sh $@
|
bash ./models/download-ggml-model.sh $@
|
||||||
cmake -B build
|
cmake -B build $(CMAKE_ARGS)
|
||||||
cmake --build build --config Release
|
cmake --build build --config Release
|
||||||
@echo ""
|
@echo ""
|
||||||
@echo "==============================================="
|
@echo "==============================================="
|
||||||
@echo "Running $@ on all samples in ./samples ..."
|
@echo "Running $@ on all samples in ./samples ..."
|
||||||
@echo "==============================================="
|
@echo "==============================================="
|
||||||
@echo ""
|
@echo ""
|
||||||
@for f in samples/*$(.flac .mp3 .ogg .wav); do \
|
@for f in samples/*.{flac,mp3,ogg,wav}; do \
|
||||||
echo "----------------------------------------------" ; \
|
echo "----------------------------------------------" ; \
|
||||||
echo "[+] Running $@ on $$f ... (run 'ffplay $$f' to listen)" ; \
|
echo "[+] Running $@ on $$f ... (run 'ffplay $$f' to listen)" ; \
|
||||||
echo "----------------------------------------------" ; \
|
echo "----------------------------------------------" ; \
|
||||||
echo "" ; \
|
echo "" ; \
|
||||||
./build/bin/whisper-cli -m models/ggml-$@.bin -f $$f ; \
|
./build/bin/whisper-cli -m models/ggml-$@.bin -f $$f ; \
|
||||||
echo "" ; \
|
echo "" ; \
|
||||||
|
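Since `cmake -B build` now receives `$(CMAKE_ARGS)`, backend flags can be forwarded directly through `make`. A small sketch, using the same flags that appear elsewhere in this changeset:

```bash
# Default CPU-only build
make build

# Forward CMake options via the new CMAKE_ARGS variable,
# e.g. the cuBLAS flag used by the updated main-cuda.Dockerfile
make base.en CMAKE_ARGS="-DGGML_CUDA=1"

# Any other -D option can be passed the same way
make build CMAKE_ARGS="-DWHISPER_SDL2=ON"
```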
156
README.md
156
README.md
@ -2,15 +2,12 @@
|
|||||||
|
|
||||||

|

|
||||||
|
|
||||||
[](https://github.com/ggerganov/whisper.cpp/actions)
|
[](https://github.com/ggml-org/whisper.cpp/actions)
|
||||||
[](https://opensource.org/licenses/MIT)
|
[](https://opensource.org/licenses/MIT)
|
||||||
[](https://conan.io/center/whisper-cpp)
|
[](https://conan.io/center/whisper-cpp)
|
||||||
[](https://www.npmjs.com/package/whisper.cpp/)
|
[](https://www.npmjs.com/package/whisper.cpp/)
|
||||||
|
|
||||||
> [!NOTE]
|
Stable: [v1.7.5](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.7.5) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/)
|
||||||
> New maintenance roadmap: https://github.com/ggerganov/whisper.cpp/discussions/2788
|
|
||||||
|
|
||||||
Stable: [v1.7.5](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.7.5) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
|
|
||||||
|
|
||||||
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
|
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
|
||||||
|
|
||||||
@ -26,7 +23,8 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp
|
|||||||
- [Efficient GPU support for NVIDIA](#nvidia-gpu-support)
|
- [Efficient GPU support for NVIDIA](#nvidia-gpu-support)
|
||||||
- [OpenVINO Support](#openvino-support)
|
- [OpenVINO Support](#openvino-support)
|
||||||
- [Ascend NPU Support](#ascend-npu-support)
|
- [Ascend NPU Support](#ascend-npu-support)
|
||||||
- [C-style API](https://github.com/ggerganov/whisper.cpp/blob/master/include/whisper.h)
|
- [Moore Threads GPU Support](#moore-threads-gpu-support)
|
||||||
|
- [C-style API](https://github.com/ggml-org/whisper.cpp/blob/master/include/whisper.h)
|
||||||
|
|
||||||
Supported platforms:
|
Supported platforms:
|
||||||
|
|
||||||
@ -34,14 +32,14 @@ Supported platforms:
|
|||||||
- [x] [iOS](examples/whisper.objc)
|
- [x] [iOS](examples/whisper.objc)
|
||||||
- [x] [Android](examples/whisper.android)
|
- [x] [Android](examples/whisper.android)
|
||||||
- [x] [Java](bindings/java/README.md)
|
- [x] [Java](bindings/java/README.md)
|
||||||
- [x] Linux / [FreeBSD](https://github.com/ggerganov/whisper.cpp/issues/56#issuecomment-1350920264)
|
- [x] Linux / [FreeBSD](https://github.com/ggml-org/whisper.cpp/issues/56#issuecomment-1350920264)
|
||||||
- [x] [WebAssembly](examples/whisper.wasm)
|
- [x] [WebAssembly](examples/whisper.wasm)
|
||||||
- [x] Windows ([MSVC](https://github.com/ggerganov/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggerganov/whisper.cpp/issues/168))
|
- [x] Windows ([MSVC](https://github.com/ggml-org/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggml-org/whisper.cpp/issues/168))
|
||||||
- [x] [Raspberry Pi](https://github.com/ggerganov/whisper.cpp/discussions/166)
|
- [x] [Raspberry Pi](https://github.com/ggml-org/whisper.cpp/discussions/166)
|
||||||
- [x] [Docker](https://github.com/ggerganov/whisper.cpp/pkgs/container/whisper.cpp)
|
- [x] [Docker](https://github.com/ggml-org/whisper.cpp/pkgs/container/whisper.cpp)
|
||||||
|
|
||||||
The entire high-level implementation of the model is contained in [whisper.h](include/whisper.h) and [whisper.cpp](src/whisper.cpp).
|
The entire high-level implementation of the model is contained in [whisper.h](include/whisper.h) and [whisper.cpp](src/whisper.cpp).
|
||||||
The rest of the code is part of the [`ggml`](https://github.com/ggerganov/ggml) machine learning library.
|
The rest of the code is part of the [`ggml`](https://github.com/ggml-org/ggml) machine learning library.
|
||||||
|
|
||||||
Having such a lightweight implementation of the model allows to easily integrate it in different platforms and applications.
|
Having such a lightweight implementation of the model allows to easily integrate it in different platforms and applications.
|
||||||
As an example, here is a video of running the model on an iPhone 13 device - fully offline, on-device: [whisper.objc](examples/whisper.objc)
|
As an example, here is a video of running the model on an iPhone 13 device - fully offline, on-device: [whisper.objc](examples/whisper.objc)
|
||||||
@ -54,14 +52,14 @@ https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a
|
|||||||
|
|
||||||
On Apple Silicon, the inference runs fully on the GPU via Metal:
|
On Apple Silicon, the inference runs fully on the GPU via Metal:
|
||||||
|
|
||||||
https://github.com/ggerganov/whisper.cpp/assets/1991296/c82e8f86-60dc-49f2-b048-d2fdbd6b5225
|
https://github.com/ggml-org/whisper.cpp/assets/1991296/c82e8f86-60dc-49f2-b048-d2fdbd6b5225
|
||||||
|
|
||||||
## Quick start
|
## Quick start
|
||||||
|
|
||||||
First clone the repository:
|
First clone the repository:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/ggerganov/whisper.cpp.git
|
git clone https://github.com/ggml-org/whisper.cpp.git
|
||||||
```
|
```
|
||||||
|
|
||||||
Navigate into the directory:
|
Navigate into the directory:
|
||||||
@ -152,6 +150,7 @@ standard cmake setup with:
|
|||||||
cmake -B build -DGGML_BLAS=1
|
cmake -B build -DGGML_BLAS=1
|
||||||
cmake --build build --config Release
|
cmake --build build --config Release
|
||||||
./build/bin/whisper-cli [ .. etc .. ]
|
./build/bin/whisper-cli [ .. etc .. ]
|
||||||
|
```
|
||||||
|
|
||||||
## Quantization
|
## Quantization
|
||||||
|
|
||||||
@ -225,7 +224,7 @@ speed-up - more than x3 faster compared with CPU-only execution. Here are the in
|
|||||||
The first run on a device is slow, since the ANE service compiles the Core ML model to some device-specific format.
|
The first run on a device is slow, since the ANE service compiles the Core ML model to some device-specific format.
|
||||||
Next runs are faster.
|
Next runs are faster.
|
||||||
|
|
||||||
For more information about the Core ML implementation please refer to PR [#566](https://github.com/ggerganov/whisper.cpp/pull/566).
|
For more information about the Core ML implementation please refer to PR [#566](https://github.com/ggml-org/whisper.cpp/pull/566).
|
||||||
|
|
||||||
## OpenVINO support
|
## OpenVINO support
|
||||||
|
|
||||||
@ -310,7 +309,7 @@ This can result in significant speedup in encoder performance. Here are the inst
|
|||||||
The first time run on an OpenVINO device is slow, since the OpenVINO framework will compile the IR (Intermediate Representation) model to a device-specific 'blob'. This device-specific blob will get
|
The first time run on an OpenVINO device is slow, since the OpenVINO framework will compile the IR (Intermediate Representation) model to a device-specific 'blob'. This device-specific blob will get
|
||||||
cached for the next run.
|
cached for the next run.
|
||||||
|
|
||||||
For more information about the OpenVINO implementation please refer to PR [#1037](https://github.com/ggerganov/whisper.cpp/pull/1037).
|
For more information about the OpenVINO implementation please refer to PR [#1037](https://github.com/ggml-org/whisper.cpp/pull/1037).
|
||||||
|
|
||||||
## NVIDIA GPU support
|
## NVIDIA GPU support
|
||||||
|
|
||||||
@ -324,6 +323,12 @@ cmake -B build -DGGML_CUDA=1
|
|||||||
cmake --build build -j --config Release
|
cmake --build build -j --config Release
|
||||||
```
|
```
|
||||||
|
|
||||||
|
or for newer NVIDIA GPUs (RTX 5000 series):
|
||||||
|
```
|
||||||
|
cmake -B build -DGGML_CUDA=1 -DCMAKE_CUDA_ARCHITECTURES="86"
|
||||||
|
cmake --build build -j --config Release
|
||||||
|
```
|
||||||
|
|
||||||
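If the right value for `CMAKE_CUDA_ARCHITECTURES` is unclear, recent NVIDIA drivers can report the GPU's compute capability directly; this query is only a hint, not part of the build itself (drop the dot to get the architecture number, e.g. `8.6` becomes `86`):

```bash
# Print the compute capability of the installed GPU (recent drivers)
nvidia-smi --query-gpu=compute_cap --format=csv,noheader
```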
## Vulkan GPU support
|
## Vulkan GPU support
|
||||||
Cross-vendor solution which allows you to accelerate workload on your GPU.
|
Cross-vendor solution which allows you to accelerate workload on your GPU.
|
||||||
First, make sure your graphics card driver provides support for Vulkan API.
|
First, make sure your graphics card driver provides support for Vulkan API.
|
||||||
@ -377,6 +382,56 @@ Run the inference examples as usual, for example:
|
|||||||
- If you have trouble with Ascend NPU device, please create a issue with **[CANN]** prefix/tag.
|
- If you have trouble with Ascend NPU device, please create a issue with **[CANN]** prefix/tag.
|
||||||
- If you run successfully with your Ascend NPU device, please help update the table `Verified devices`.
|
- If you run successfully with your Ascend NPU device, please help update the table `Verified devices`.
|
||||||
|
|
||||||
|
## Moore Threads GPU support
|
||||||
|
|
||||||
|
With Moore Threads cards the processing of the models is done efficiently on the GPU via muBLAS and custom MUSA kernels.
|
||||||
|
First, make sure you have installed `MUSA SDK rc3.1.1`: https://developer.mthreads.com/sdk/download/musa?equipment=&os=&driverVersion=&version=rc3.1.1
|
||||||
|
|
||||||
|
Now build `whisper.cpp` with MUSA support:
|
||||||
|
|
||||||
|
```
|
||||||
|
cmake -B build -DGGML_MUSA=1
|
||||||
|
cmake --build build -j --config Release
|
||||||
|
```
|
||||||
|
|
||||||
|
or specify the architecture for your Moore Threads GPU. For example, if you have an MTT S80 GPU, you can specify the architecture as follows:
|
||||||
|
|
||||||
|
```
|
||||||
|
cmake -B build -DGGML_MUSA=1 -DMUSA_ARCHITECTURES="21"
|
||||||
|
cmake --build build -j --config Release
|
||||||
|
```
|
||||||
|
|
||||||
|
## FFmpeg support (Linux only)
|
||||||
|
|
||||||
|
If you want to support more audio formats (such as Opus and AAC), you can turn on the `WHISPER_FFMPEG` build flag to enable FFmpeg integration.
|
||||||
|
|
||||||
|
First, you need to install required libraries:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Debian/Ubuntu
|
||||||
|
sudo apt install libavcodec-dev libavformat-dev libavutil-dev
|
||||||
|
|
||||||
|
# RHEL/Fedora
|
||||||
|
sudo dnf install libavcodec-free-devel libavformat-free-devel libavutil-free-devel
|
||||||
|
```
|
||||||
|
|
||||||
|
Then you can build the project as follows:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cmake -B build -D WHISPER_FFMPEG=yes
|
||||||
|
cmake --build build
|
||||||
|
```
|
||||||
|
|
||||||
|
Run the following example to confirm it's working:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Convert an audio file to Opus format
|
||||||
|
ffmpeg -i samples/jfk.wav jfk.opus
|
||||||
|
|
||||||
|
# Transcribe the audio file
|
||||||
|
./build/bin/whisper-cli --model models/ggml-base.en.bin --file jfk.opus
|
||||||
|
```
|
||||||
|
|
||||||
## Docker
|
## Docker
|
||||||
|
|
||||||
### Prerequisites
|
### Prerequisites
|
||||||
@ -388,8 +443,9 @@ Run the inference examples as usual, for example:
|
|||||||
|
|
||||||
We have two Docker images available for this project:
|
We have two Docker images available for this project:
|
||||||
|
|
||||||
1. `ghcr.io/ggerganov/whisper.cpp:main`: This image includes the main executable file as well as `curl` and `ffmpeg`. (platforms: `linux/amd64`, `linux/arm64`)
|
1. `ghcr.io/ggml-org/whisper.cpp:main`: This image includes the main executable file as well as `curl` and `ffmpeg`. (platforms: `linux/amd64`, `linux/arm64`)
|
||||||
2. `ghcr.io/ggerganov/whisper.cpp:main-cuda`: Same as `main` but compiled with CUDA support. (platforms: `linux/amd64`)
|
2. `ghcr.io/ggml-org/whisper.cpp:main-cuda`: Same as `main` but compiled with CUDA support. (platforms: `linux/amd64`)
|
||||||
|
3. `ghcr.io/ggml-org/whisper.cpp:main-musa`: Same as `main` but compiled with MUSA support. (platforms: `linux/amd64`)
|
||||||
|
|
||||||
### Usage
|
### Usage
|
||||||
|
|
||||||
@ -402,11 +458,11 @@ docker run -it --rm \
|
|||||||
docker run -it --rm \
|
docker run -it --rm \
|
||||||
-v path/to/models:/models \
|
-v path/to/models:/models \
|
||||||
-v path/to/audios:/audios \
|
-v path/to/audios:/audios \
|
||||||
whisper.cpp:main "./main -m /models/ggml-base.bin -f /audios/jfk.wav"
|
whisper.cpp:main "whisper-cli -m /models/ggml-base.bin -f /audios/jfk.wav"
|
||||||
# transcribe an audio file in samples folder
|
# transcribe an audio file in samples folder
|
||||||
docker run -it --rm \
|
docker run -it --rm \
|
||||||
-v path/to/models:/models \
|
-v path/to/models:/models \
|
||||||
whisper.cpp:main "./main -m /models/ggml-base.bin -f ./samples/jfk.wav"
|
whisper.cpp:main "whisper-cli -m /models/ggml-base.bin -f ./samples/jfk.wav"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Installing with Conan
|
## Installing with Conan
|
||||||
@ -427,8 +483,8 @@ For detailed instructions on how to use Conan, please refer to the [Conan docume
|
|||||||
|
|
||||||
This is a naive example of performing real-time inference on audio from your microphone.
|
This is a naive example of performing real-time inference on audio from your microphone.
|
||||||
The [stream](examples/stream) tool samples the audio every half a second and runs the transcription continuously.
|
The [stream](examples/stream) tool samples the audio every half a second and runs the transcription continuously.
|
||||||
More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10).
|
More info is available in [issue #10](https://github.com/ggml-org/whisper.cpp/issues/10).
|
||||||
You will need to have [sdl2](https://wiki.libsdl.org/SDL2/Installation) installed for it to work properly.
|
You will need to have [sdl2](https://wiki.libsdl.org/SDL2/Installation) installed for it to work properly.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cmake -B build -DWHISPER_SDL2=ON
|
cmake -B build -DWHISPER_SDL2=ON
|
||||||
@ -516,7 +572,7 @@ main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 pr
|
|||||||
|
|
||||||
## Speaker segmentation via tinydiarize (experimental)
|
## Speaker segmentation via tinydiarize (experimental)
|
||||||
|
|
||||||
More information about this approach is available here: https://github.com/ggerganov/whisper.cpp/pull/1058
|
More information about this approach is available here: https://github.com/ggml-org/whisper.cpp/pull/1058
|
||||||
|
|
||||||
Sample usage:
|
Sample usage:
|
||||||
|
|
||||||
@ -580,7 +636,7 @@ https://user-images.githubusercontent.com/1991296/199337538-b7b0c7a3-2753-4a88-a
|
|||||||
|
|
||||||
## Video comparison of different models
|
## Video comparison of different models
|
||||||
|
|
||||||
Use the [scripts/bench-wts.sh](https://github.com/ggerganov/whisper.cpp/blob/master/scripts/bench-wts.sh) script to generate a video in the following format:
|
Use the [scripts/bench-wts.sh](https://github.com/ggml-org/whisper.cpp/blob/master/scripts/bench-wts.sh) script to generate a video in the following format:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./scripts/bench-wts.sh samples/jfk.wav
|
./scripts/bench-wts.sh samples/jfk.wav
|
||||||
@ -597,7 +653,7 @@ In order to have an objective comparison of the performance of the inference acr
|
|||||||
use the [whisper-bench](examples/bench) tool. The tool simply runs the Encoder part of the model and prints how much time it
|
use the [whisper-bench](examples/bench) tool. The tool simply runs the Encoder part of the model and prints how much time it
|
||||||
took to execute it. The results are summarized in the following Github issue:
|
took to execute it. The results are summarized in the following Github issue:
|
||||||
|
|
||||||
[Benchmark results](https://github.com/ggerganov/whisper.cpp/issues/89)
|
[Benchmark results](https://github.com/ggml-org/whisper.cpp/issues/89)
|
||||||
|
|
||||||
Additionally a script to run whisper.cpp with different models and audio files is provided [bench.py](scripts/bench.py).
|
Additionally a script to run whisper.cpp with different models and audio files is provided [bench.py](scripts/bench.py).
|
||||||
|
|
||||||
@ -624,25 +680,24 @@ You can download the converted models using the [models/download-ggml-model.sh](
|
|||||||
or manually from here:
|
or manually from here:
|
||||||
|
|
||||||
- https://huggingface.co/ggerganov/whisper.cpp
|
- https://huggingface.co/ggerganov/whisper.cpp
|
||||||
- https://ggml.ggerganov.com
|
|
||||||
|
|
||||||
For more details, see the conversion script [models/convert-pt-to-ggml.py](models/convert-pt-to-ggml.py) or [models/README.md](models/README.md).
|
For more details, see the conversion script [models/convert-pt-to-ggml.py](models/convert-pt-to-ggml.py) or [models/README.md](models/README.md).
|
||||||
|
|
||||||
## [Bindings](https://github.com/ggerganov/whisper.cpp/discussions/categories/bindings)
|
## [Bindings](https://github.com/ggml-org/whisper.cpp/discussions/categories/bindings)
|
||||||
|
|
||||||
- [x] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggerganov/whisper.cpp/discussions/310)
|
- [x] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggml-org/whisper.cpp/discussions/310)
|
||||||
- [x] JavaScript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309)
|
- [x] JavaScript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggml-org/whisper.cpp/discussions/309)
|
||||||
- React Native (iOS / Android): [whisper.rn](https://github.com/mybigday/whisper.rn)
|
- React Native (iOS / Android): [whisper.rn](https://github.com/mybigday/whisper.rn)
|
||||||
- [x] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312)
|
- [x] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggml-org/whisper.cpp/discussions/312)
|
||||||
- [x] Java:
|
- [x] Java:
|
||||||
- [GiviMAD/whisper-jni](https://github.com/GiviMAD/whisper-jni)
|
- [GiviMAD/whisper-jni](https://github.com/GiviMAD/whisper-jni)
|
||||||
- [x] Ruby: [bindings/ruby](bindings/ruby) | [#507](https://github.com/ggerganov/whisper.cpp/discussions/507)
|
- [x] Ruby: [bindings/ruby](bindings/ruby) | [#507](https://github.com/ggml-org/whisper.cpp/discussions/507)
|
||||||
- [x] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm) | [#313](https://github.com/ggerganov/whisper.cpp/discussions/313)
|
- [x] Objective-C / Swift: [ggml-org/whisper.spm](https://github.com/ggml-org/whisper.spm) | [#313](https://github.com/ggml-org/whisper.cpp/discussions/313)
|
||||||
- [exPHAT/SwiftWhisper](https://github.com/exPHAT/SwiftWhisper)
|
- [exPHAT/SwiftWhisper](https://github.com/exPHAT/SwiftWhisper)
|
||||||
- [x] .NET: | [#422](https://github.com/ggerganov/whisper.cpp/discussions/422)
|
- [x] .NET: | [#422](https://github.com/ggml-org/whisper.cpp/discussions/422)
|
||||||
- [sandrohanea/whisper.net](https://github.com/sandrohanea/whisper.net)
|
- [sandrohanea/whisper.net](https://github.com/sandrohanea/whisper.net)
|
||||||
- [NickDarvey/whisper](https://github.com/NickDarvey/whisper)
|
- [NickDarvey/whisper](https://github.com/NickDarvey/whisper)
|
||||||
- [x] Python: | [#9](https://github.com/ggerganov/whisper.cpp/issues/9)
|
- [x] Python: | [#9](https://github.com/ggml-org/whisper.cpp/issues/9)
|
||||||
- [stlukey/whispercpp.py](https://github.com/stlukey/whispercpp.py) (Cython)
|
- [stlukey/whispercpp.py](https://github.com/stlukey/whispercpp.py) (Cython)
|
||||||
- [AIWintermuteAI/whispercpp](https://github.com/AIWintermuteAI/whispercpp) (Updated fork of aarnphm/whispercpp)
|
- [AIWintermuteAI/whispercpp](https://github.com/AIWintermuteAI/whispercpp) (Updated fork of aarnphm/whispercpp)
|
||||||
- [aarnphm/whispercpp](https://github.com/aarnphm/whispercpp) (Pybind11)
|
- [aarnphm/whispercpp](https://github.com/aarnphm/whispercpp) (Pybind11)
|
||||||
@ -650,6 +705,33 @@ For more details, see the conversion script [models/convert-pt-to-ggml.py](model
|
|||||||
- [x] R: [bnosac/audio.whisper](https://github.com/bnosac/audio.whisper)
|
- [x] R: [bnosac/audio.whisper](https://github.com/bnosac/audio.whisper)
|
||||||
- [x] Unity: [macoron/whisper.unity](https://github.com/Macoron/whisper.unity)
|
- [x] Unity: [macoron/whisper.unity](https://github.com/Macoron/whisper.unity)
|
||||||
|
|
||||||
|
## XCFramework
|
||||||
|
The XCFramework is a precompiled version of the library for iOS, visionOS, tvOS,
|
||||||
|
and macOS. It can be used in Swift projects without the need to compile the
|
||||||
|
library from source. For example:
|
||||||
|
```swift
|
||||||
|
// swift-tools-version: 5.10
|
||||||
|
// The swift-tools-version declares the minimum version of Swift required to build this package.
|
||||||
|
|
||||||
|
import PackageDescription
|
||||||
|
|
||||||
|
let package = Package(
|
||||||
|
name: "Whisper",
|
||||||
|
targets: [
|
||||||
|
.executableTarget(
|
||||||
|
name: "Whisper",
|
||||||
|
dependencies: [
|
||||||
|
"WhisperFramework"
|
||||||
|
]),
|
||||||
|
.binaryTarget(
|
||||||
|
name: "WhisperFramework",
|
||||||
|
url: "https://github.com/ggml-org/whisper.cpp/releases/download/v1.7.5/whisper-v1.7.5-xcframework.zip",
|
||||||
|
checksum: "c7faeb328620d6012e130f3d705c51a6ea6c995605f2df50f6e1ad68c59c6c4a"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
There are various examples of using the library for different projects in the [examples](examples) folder.
|
There are various examples of using the library for different projects in the [examples](examples) folder.
|
||||||
@ -668,13 +750,13 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
|
|||||||
| [whisper.android](examples/whisper.android) | | Android mobile application using whisper.cpp |
|
| [whisper.android](examples/whisper.android) | | Android mobile application using whisper.cpp |
|
||||||
| [whisper.nvim](examples/whisper.nvim) | | Speech-to-text plugin for Neovim |
|
| [whisper.nvim](examples/whisper.nvim) | | Speech-to-text plugin for Neovim |
|
||||||
| [generate-karaoke.sh](examples/generate-karaoke.sh) | | Helper script to easily [generate a karaoke video](https://youtu.be/uj7hVta4blM) of raw audio capture |
|
| [generate-karaoke.sh](examples/generate-karaoke.sh) | | Helper script to easily [generate a karaoke video](https://youtu.be/uj7hVta4blM) of raw audio capture |
|
||||||
| [livestream.sh](examples/livestream.sh) | | [Livestream audio transcription](https://github.com/ggerganov/whisper.cpp/issues/185) |
|
| [livestream.sh](examples/livestream.sh) | | [Livestream audio transcription](https://github.com/ggml-org/whisper.cpp/issues/185) |
|
||||||
| [yt-wsp.sh](examples/yt-wsp.sh) | | Download + transcribe and/or translate any VOD [(original)](https://gist.github.com/DaniruKun/96f763ec1a037cc92fe1a059b643b818) |
|
| [yt-wsp.sh](examples/yt-wsp.sh) | | Download + transcribe and/or translate any VOD [(original)](https://gist.github.com/DaniruKun/96f763ec1a037cc92fe1a059b643b818) |
|
||||||
| [wchess](examples/wchess) | [wchess.wasm](examples/wchess) | Voice-controlled chess |
|
| [wchess](examples/wchess) | [wchess.wasm](examples/wchess) | Voice-controlled chess |
|
||||||
|
|
||||||
## [Discussions](https://github.com/ggerganov/whisper.cpp/discussions)
|
## [Discussions](https://github.com/ggml-org/whisper.cpp/discussions)
|
||||||
|
|
||||||
If you have any kind of feedback about this project feel free to use the Discussions section and open a new topic.
|
If you have any kind of feedback about this project feel free to use the Discussions section and open a new topic.
|
||||||
You can use the [Show and tell](https://github.com/ggerganov/whisper.cpp/discussions/categories/show-and-tell) category
|
You can use the [Show and tell](https://github.com/ggml-org/whisper.cpp/discussions/categories/show-and-tell) category
|
||||||
to share your own projects that use `whisper.cpp`. If you have a question, make sure to check the
|
to share your own projects that use `whisper.cpp`. If you have a question, make sure to check the
|
||||||
[Frequently asked questions (#126)](https://github.com/ggerganov/whisper.cpp/discussions/126) discussion.
|
[Frequently asked questions (#126)](https://github.com/ggml-org/whisper.cpp/discussions/126) discussion.
|
||||||
|
@ -51,7 +51,7 @@ func main() {
|
|||||||
In order to build, you need to have the Go compiler installed. You can get it from [here](https://golang.org/dl/). Run the tests with:
|
In order to build, you need to have the Go compiler installed. You can get it from [here](https://golang.org/dl/). Run the tests with:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/ggerganov/whisper.cpp.git
|
git clone https://github.com/ggml-org/whisper.cpp.git
|
||||||
cd whisper.cpp/bindings/go
|
cd whisper.cpp/bindings/go
|
||||||
make test
|
make test
|
||||||
```
|
```
|
||||||
@ -98,7 +98,7 @@ The API Documentation:
|
|||||||
|
|
||||||
Getting help:
|
Getting help:
|
||||||
|
|
||||||
* Follow the discussion for the go bindings [here](https://github.com/ggerganov/whisper.cpp/discussions/312)
|
* Follow the discussion for the go bindings [here](https://github.com/ggml-org/whisper.cpp/discussions/312)
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
github.com/ggerganov/whisper.cpp/bindings/go
|
github.com/ggml-org/whisper.cpp/bindings/go
|
||||||
provides a speech-to-text service bindings for the Go programming language.
|
provides a speech-to-text service bindings for the Go programming language.
|
||||||
*/
|
*/
|
||||||
package whisper
|
package whisper
|
||||||
|
@ -31,10 +31,10 @@ public class Example {
|
|||||||
var whisperParams = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
var whisperParams = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||||
// custom configuration if required
|
// custom configuration if required
|
||||||
whisperParams.temperature_inc = 0f;
|
whisperParams.temperature_inc = 0f;
|
||||||
|
|
||||||
var samples = readAudio(); // divide each value by 32767.0f
|
var samples = readAudio(); // divide each value by 32767.0f
|
||||||
whisper.fullTranscribe(whisperParams, samples);
|
whisper.fullTranscribe(whisperParams, samples);
|
||||||
|
|
||||||
int segmentCount = whisper.getTextSegmentCount(context);
|
int segmentCount = whisper.getTextSegmentCount(context);
|
||||||
for (int i = 0; i < segmentCount; i++) {
|
for (int i = 0; i < segmentCount; i++) {
|
||||||
String text = whisper.getTextSegment(context, i);
|
String text = whisper.getTextSegment(context, i);
|
||||||
@ -52,7 +52,7 @@ public class Example {
|
|||||||
In order to build, you need to have the JDK 8 or higher installed. Run the tests with:
|
In order to build, you need to have the JDK 8 or higher installed. Run the tests with:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/ggerganov/whisper.cpp.git
|
git clone https://github.com/ggml-org/whisper.cpp.git
|
||||||
cd whisper.cpp/bindings/java
|
cd whisper.cpp/bindings/java
|
||||||
|
|
||||||
./gradlew build
|
./gradlew build
|
||||||
|
@ -27,23 +27,41 @@ sourceSets {
|
|||||||
tasks.register('copyLibwhisperDynlib', Copy) {
|
tasks.register('copyLibwhisperDynlib', Copy) {
|
||||||
from '../../build/src'
|
from '../../build/src'
|
||||||
include 'libwhisper.dylib'
|
include 'libwhisper.dylib'
|
||||||
into 'build/generated/resources/main/darwin'
|
into 'build/generated/resources/main'
|
||||||
}
|
}
|
||||||
|
|
||||||
tasks.register('copyLibwhisperSo', Copy) {
|
tasks.register('copyLibwhisperSo', Copy) {
|
||||||
from '../../build/src'
|
from '../../build/src'
|
||||||
include 'libwhisper.so'
|
include 'libwhisper.so'
|
||||||
into 'build/generated/resources/main/linux-x86-64'
|
into 'build/generated/resources/main'
|
||||||
}
|
}
|
||||||
|
|
||||||
tasks.register('copyWhisperDll', Copy) {
|
tasks.register('copyWhisperDLL', Copy) {
|
||||||
from '../../build/Release'
|
from '../../build/bin/Release'
|
||||||
include 'whisper.dll'
|
include 'whisper.dll'
|
||||||
into 'build/generated/resources/main/windows-x86-64'
|
into 'build/generated/resources/main'
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks.register('copyGGML_BASE_DLL', Copy) {
|
||||||
|
from '../../build/bin/Release'
|
||||||
|
include 'ggml-base.dll'
|
||||||
|
into 'build/generated/resources/main'
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks.register('copyGGML_DLL', Copy) {
|
||||||
|
from '../../build/bin/Release'
|
||||||
|
include 'ggml.dll'
|
||||||
|
into 'build/generated/resources/main'
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks.register('copyGGML_CPU_DLL', Copy) {
|
||||||
|
from '../../build/bin/Release'
|
||||||
|
include 'ggml-cpu.dll'
|
||||||
|
into 'build/generated/resources/main'
|
||||||
}
|
}
|
||||||
|
|
||||||
tasks.register('copyLibs') {
|
tasks.register('copyLibs') {
|
||||||
dependsOn copyLibwhisperDynlib, copyLibwhisperSo, copyWhisperDll
|
dependsOn copyLibwhisperDynlib, copyLibwhisperSo, copyWhisperDLL, copyGGML_BASE_DLL, copyGGML_DLL, copyGGML_CPU_DLL
|
||||||
}
|
}
|
||||||
|
|
||||||
test {
|
test {
|
||||||
|
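These copy tasks stage the native libraries out of the CMake build tree, so the native build must exist before the Gradle build runs. A rough local sequence, assuming the default `build/` directory that the `from` paths above point at:

```bash
# Build the native libraries first (shared libs land in build/src on Linux/macOS,
# DLLs in build/bin/Release on Windows - the paths the copy tasks read from)
cmake -B build
cmake --build build --config Release

# Stage the libraries for JNA and build/test the Java bindings
cd bindings/java
./gradlew copyLibs build
```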
@ -9,6 +9,7 @@ import io.github.ggerganov.whispercpp.params.WhisperContextParams;
|
|||||||
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
|
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
|
||||||
|
|
||||||
public interface WhisperCppJnaLibrary extends Library {
|
public interface WhisperCppJnaLibrary extends Library {
|
||||||
|
|
||||||
WhisperCppJnaLibrary instance = Native.load("whisper", WhisperCppJnaLibrary.class);
|
WhisperCppJnaLibrary instance = Native.load("whisper", WhisperCppJnaLibrary.class);
|
||||||
|
|
||||||
String whisper_print_system_info();
|
String whisper_print_system_info();
|
||||||
|
3
bindings/ruby/.gitignore
vendored
3
bindings/ruby/.gitignore
vendored
@ -1,3 +1,6 @@
|
|||||||
LICENSE
|
LICENSE
|
||||||
pkg/
|
pkg/
|
||||||
lib/whisper.*
|
lib/whisper.*
|
||||||
|
ext/sources/*
|
||||||
|
!ext/sources/CMakeGraphVizOptions.cmake
|
||||||
|
ext/mkmf.log
|
||||||
|
@ -16,6 +16,18 @@ If bundler is not being used to manage dependencies, install the gem by executin
|
|||||||
|
|
||||||
$ gem install whispercpp
|
$ gem install whispercpp
|
||||||
|
|
||||||
|
You can pass build options for whisper.cpp, for instance:
|
||||||
|
|
||||||
|
$ bundle config build.whispercpp --enable-ggml-cuda
|
||||||
|
|
||||||
|
or,
|
||||||
|
|
||||||
|
$ gem install whispercpp -- --enable-ggml-cuda
|
||||||
|
|
||||||
|
See whisper.cpp's [README](https://github.com/ggml-org/whisper.cpp/blob/master/README.md) for available options. You need to convert the options listed in the README to Ruby-style options.
|
||||||
|
For boolean options like `GGML_CUDA`, the README says `-DGGML_CUDA=1`. You need to strip `-D`, prepend `--enable-` for `1` or `ON` (`--disable-` for `0` or `OFF`), and make it kebab-case: `--enable-ggml-cuda`.
|
||||||
|
For options which require arguments, like `CMAKE_CUDA_ARCHITECTURES`, the README says `-DCMAKE_CUDA_ARCHITECTURES="86"`. You need to strip `-D`, prepend `--`, make it kebab-case, append `=`, and append the argument: `--cmake-cuda-architectures="86"`.
|
||||||
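Combining the two rules above, a boolean flag and a valued option can be passed together; the option names are just the examples already mentioned:

```bash
# GGML_CUDA=1                    ->  --enable-ggml-cuda
# CMAKE_CUDA_ARCHITECTURES="86"  ->  --cmake-cuda-architectures="86"
gem install whispercpp -- --enable-ggml-cuda --cmake-cuda-architectures="86"

# Or persist them for Bundler-managed installs
bundle config build.whispercpp --enable-ggml-cuda
```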
|
|
||||||
Usage
|
Usage
|
||||||
-----
|
-----
|
||||||
|
|
||||||
@ -228,7 +240,7 @@ The second argument `samples` may be an array, an object with `length` and `each
|
|||||||
Development
|
Development
|
||||||
-----------
|
-----------
|
||||||
|
|
||||||
% git clone https://github.com/ggerganov/whisper.cpp.git
|
% git clone https://github.com/ggml-org/whisper.cpp.git
|
||||||
% cd whisper.cpp/bindings/ruby
|
% cd whisper.cpp/bindings/ruby
|
||||||
% rake test
|
% rake test
|
||||||
|
|
||||||
@ -241,5 +253,5 @@ License
|
|||||||
|
|
||||||
The same to [whisper.cpp][].
|
The same to [whisper.cpp][].
|
||||||
|
|
||||||
[whisper.cpp]: https://github.com/ggerganov/whisper.cpp
|
[whisper.cpp]: https://github.com/ggml-org/whisper.cpp
|
||||||
[models]: https://github.com/ggerganov/whisper.cpp/tree/master/models
|
[models]: https://github.com/ggml-org/whisper.cpp/tree/master/models
|
||||||
|
@ -3,11 +3,15 @@ require "bundler/gem_tasks"
|
|||||||
require "rake/testtask"
|
require "rake/testtask"
|
||||||
require_relative "extsources"
|
require_relative "extsources"
|
||||||
|
|
||||||
|
SOURCES_DIR = "ext/sources"
|
||||||
|
|
||||||
SOURCES = FileList[]
|
SOURCES = FileList[]
|
||||||
|
|
||||||
EXTSOURCES.each do |src|
|
EXTSOURCES.each do |src|
|
||||||
basename = src.pathmap("%f")
|
basename = src.pathmap("%f")
|
||||||
dest = basename == "LICENSE" ? basename : src.pathmap("%{../..,ext}p")
|
dest = basename == "LICENSE" ? basename
|
||||||
|
: src.pathmap("%{\\.\\./\\.\\.,#{SOURCES_DIR}}p")
|
||||||
|
.pathmap("%{\\.\\./javascript,#{SOURCES_DIR}/bindings/javascript}p")
|
||||||
dir = dest.pathmap("%d")
|
dir = dest.pathmap("%d")
|
||||||
file src
|
file src
|
||||||
directory dir
|
directory dir
|
||||||
@ -18,7 +22,6 @@ EXTSOURCES.each do |src|
|
|||||||
end
|
end
|
||||||
|
|
||||||
CLEAN.include SOURCES
|
CLEAN.include SOURCES
|
||||||
CLEAN.include FileList["ext/**/*.o", "ext/**/*.metal", "ext/**/*.tmp", "ext/whisper.{so,bundle,dll}"]
|
|
||||||
|
|
||||||
SRC = FileList["ext/*.{c,cpp,h}"]
|
SRC = FileList["ext/*.{c,cpp,h}"]
|
||||||
|
|
||||||
@ -36,6 +39,20 @@ file "ext/Makefile" => SRC + ["ext/extconf.rb"] + SOURCES do |t|
|
|||||||
ruby "extconf.rb"
|
ruby "extconf.rb"
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
if File.exist? "ext/Makefile"
|
||||||
|
task :make_clean do
|
||||||
|
cd "ext" do
|
||||||
|
sh "make", "clean"
|
||||||
|
end
|
||||||
|
end
|
||||||
|
task clean: :make_clean
|
||||||
|
task :make_distclean do
|
||||||
|
cd "ext" do
|
||||||
|
sh "make", "distclean"
|
||||||
|
end
|
||||||
|
end
|
||||||
|
task clobber: :make_distclean
|
||||||
|
end
|
||||||
|
|
||||||
file SO_FILE => "ext/Makefile" do |t|
|
file SO_FILE => "ext/Makefile" do |t|
|
||||||
chdir "ext" do
|
chdir "ext" do
|
||||||
|
@ -1,11 +0,0 @@
|
|||||||
ggml/src/ggml-cpu/ggml-cpu-cpp.o: \
|
|
||||||
ggml/src/ggml-cpu/ggml-cpu.cpp \
|
|
||||||
ggml/src/ggml-cpu/unary-ops.cpp \
|
|
||||||
ggml/src/ggml-cpu/binary-ops.cpp \
|
|
||||||
ggml/include/ggml-backend.h \
|
|
||||||
ggml/include/ggml.h \
|
|
||||||
ggml/include/ggml-alloc.h \
|
|
||||||
ggml/src/ggml-backend-impl.h \
|
|
||||||
ggml/include/ggml-cpu.h \
|
|
||||||
ggml/src/ggml-impl.h
|
|
||||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
|
61
bindings/ruby/ext/dependencies.rb
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
require "tsort"
|
||||||
|
|
||||||
|
class Dependencies
|
||||||
|
def initialize(cmake, options)
|
||||||
|
@cmake = cmake
|
||||||
|
@options = options
|
||||||
|
|
||||||
|
generate_dot
|
||||||
|
@libs = parse_dot
|
||||||
|
end
|
||||||
|
|
||||||
|
def to_s
|
||||||
|
@libs.join(" ")
|
||||||
|
end
|
||||||
|
|
||||||
|
private
|
||||||
|
|
||||||
|
def dot_path
|
||||||
|
File.join(__dir__, "build", "whisper.cpp.dot")
|
||||||
|
end
|
||||||
|
|
||||||
|
def generate_dot
|
||||||
|
system @cmake, "-S", "sources", "-B", "build", "--graphviz", dot_path, "-D", "BUILD_SHARED_LIBS=OFF", @options.to_s, exception: true
|
||||||
|
end
|
||||||
|
|
||||||
|
def parse_dot
|
||||||
|
static_lib_shape = nil
|
||||||
|
nodes = {}
|
||||||
|
depends = Hash.new {|h, k| h[k] = []}
|
||||||
|
|
||||||
|
class << depends
|
||||||
|
include TSort
|
||||||
|
alias tsort_each_node each_key
|
||||||
|
def tsort_each_child(node, &block)
|
||||||
|
fetch(node, []).each(&block)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
File.open(dot_path).each_line do |line|
|
||||||
|
case line
|
||||||
|
when /\[\s*label\s*=\s*"Static Library"\s*,\s*shape\s*=\s*(?<shape>\w+)\s*\]/
|
||||||
|
static_lib_shape = $~[:shape]
|
||||||
|
when /\A\s*"(?<node>\w+)"\s*\[\s*label\s*=\s*"(?<label>\S+)"\s*,\s*shape\s*=\s*(?<shape>\w+)\s*\]\s*;\s*\z/
|
||||||
|
node = $~[:node]
|
||||||
|
label = $~[:label]
|
||||||
|
shape = $~[:shape]
|
||||||
|
nodes[node] = [label, shape]
|
||||||
|
when /\A\s*"(?<depender>\w+)"\s*->\s*"(?<dependee>\w+)"/
|
||||||
|
depender = $~[:depender]
|
||||||
|
dependee = $~[:dependee]
|
||||||
|
depends[depender] ||= []
|
||||||
|
depends[depender] << dependee
|
||||||
|
end
|
||||||
|
end
|
||||||
|
depends.tsort.filter_map {|node|
|
||||||
|
label, shape = nodes[node]
|
||||||
|
shape == static_lib_shape ? label : nil
|
||||||
|
}.collect {|lib| "lib#{lib}.a"}
|
||||||
|
.reverse
|
||||||
|
end
|
||||||
|
end
|
@ -1,210 +1,22 @@
|
|||||||
require 'mkmf'
|
require "mkmf"
|
||||||
|
require_relative "options"
|
||||||
|
require_relative "dependencies"
|
||||||
|
|
||||||
# need to use c++ compiler flags
|
cmake = find_executable("cmake") || abort
|
||||||
$CXXFLAGS << ' -std=c++17'
|
options = Options.new
|
||||||
|
have_library("gomp") rescue nil
|
||||||
|
libs = Dependencies.new(cmake, options)
|
||||||
|
|
||||||
$LDFLAGS << ' -lstdc++'
|
$INCFLAGS << " -Isources/include -Isources/ggml/include -Isources/examples"
|
||||||
|
$LOCAL_LIBS << " #{libs}"
|
||||||
|
$cleanfiles << " build #{libs}"
|
||||||
|
|
||||||
# Set to true when building binary gems
|
create_makefile "whisper" do |conf|
|
||||||
if enable_config('static-stdlib', false)
|
conf << <<~EOF
|
||||||
$LDFLAGS << ' -static-libgcc -static-libstdc++'
|
$(TARGET_SO): #{libs}
|
||||||
end
|
#{libs}: cmake-targets
|
||||||
|
cmake-targets:
|
||||||
if enable_config('march-tune-native', false)
|
#{"\t"}#{cmake} -S sources -B build -D BUILD_SHARED_LIBS=OFF -D CMAKE_ARCHIVE_OUTPUT_DIRECTORY=#{__dir__} -D CMAKE_POSITION_INDEPENDENT_CODE=ON #{options}
|
||||||
$CFLAGS << ' -march=native -mtune=native'
|
#{"\t"}#{cmake} --build build --config Release --target common whisper
|
||||||
$CXXFLAGS << ' -march=native -mtune=native'
|
EOF
|
||||||
end
|
|
||||||
|
|
||||||
if ENV['WHISPER_METAL']
|
|
||||||
$GGML_METAL ||= true
|
|
||||||
$DEPRECATE_WARNING ||= true
|
|
||||||
end
|
|
||||||
|
|
||||||
$UNAME_S = `uname -s`.chomp
|
|
||||||
$UNAME_P = `uname -p`.chomp
|
|
||||||
$UNAME_M = `uname -m`.chomp
|
|
||||||
|
|
||||||
if $UNAME_S == 'Darwin'
|
|
||||||
unless ENV['GGML_NO_METAL']
|
|
||||||
$GGML_METAL ||= true
|
|
||||||
end
|
|
||||||
$GGML_NO_OPENMP ||= true
|
|
||||||
end
|
|
||||||
|
|
||||||
if $GGML_METAL
|
|
||||||
$GGML_METAL_EMBED_LIBRARY = true
|
|
||||||
end
|
|
||||||
|
|
||||||
$MK_CPPFLAGS = '-Iggml/include -Iggml/src -Iggml/src/ggml-cpu -Iinclude -Isrc -Iexamples -DGGML_USE_CPU'
|
|
||||||
$MK_CFLAGS = '-std=c11 -fPIC'
|
|
||||||
$MK_CXXFLAGS = '-std=c++17 -fPIC'
|
|
||||||
$MK_NVCCFLAGS = '-std=c++17'
|
|
||||||
$MK_LDFLAGS = ''
|
|
||||||
|
|
||||||
$OBJ_GGML = []
|
|
||||||
$OBJ_WHISPER = []
|
|
||||||
$OBJ_COMMON = []
|
|
||||||
$OBJ_SDL = []
|
|
||||||
|
|
||||||
$MK_CPPFLAGS << ' -D_XOPEN_SOURCE=600'
|
|
||||||
|
|
||||||
if $UNAME_S == 'Linux'
|
|
||||||
$MK_CPPFLAGS << ' -D_GNU_SOURCE'
|
|
||||||
end
|
|
||||||
|
|
||||||
if $UNAME_S == 'Darwin'
|
|
||||||
$MK_CPPFLAGS << ' -D_DARWIN_C_SOURCE'
|
|
||||||
end
|
|
||||||
|
|
||||||
if ENV['WHISPER_DEBUG']
|
|
||||||
$MK_CFLAGS << ' -O0 -g'
|
|
||||||
$MK_CXXFLAGS << ' -O0 -g'
|
|
||||||
$MK_LDFLAGS << ' -g'
|
|
||||||
$MK_NVCCFLAGS << ' -O0 -g'
|
|
||||||
else
|
|
||||||
$MK_CPPFLAGS << ' -DNDEBUG'
|
|
||||||
$MK_CFLAGS << ' -O3'
|
|
||||||
$MK_CXXFLAGS << ' -O3'
|
|
||||||
$MK_NVCCFLAGS << ' -O3'
|
|
||||||
end
|
|
||||||
|
|
||||||
$WARN_FLAGS =
|
|
||||||
' -Wall' <<
|
|
||||||
' -Wextra' <<
|
|
||||||
' -Wpedantic' <<
|
|
||||||
' -Wcast-qual' <<
|
|
||||||
' -Wno-unused-function'
|
|
||||||
|
|
||||||
$MK_CFLAGS <<
|
|
||||||
$WARN_FLAGS <<
|
|
||||||
' -Wshadow' <<
|
|
||||||
' -Wstrict-prototypes' <<
|
|
||||||
' -Wpointer-arith' <<
|
|
||||||
' -Wmissing-prototypes' <<
|
|
||||||
' -Werror=implicit-int' <<
|
|
||||||
' -Werror=implicit-function-declaration'
|
|
||||||
|
|
||||||
$MK_CXXFLAGS <<
|
|
||||||
$WARN_FLAGS <<
|
|
||||||
' -Wmissing-declarations' <<
|
|
||||||
' -Wmissing-noreturn'
|
|
||||||
|
|
||||||
unless `#{cc_command} #{$LDFLAGS} -Wl,-v 2>&1`.chomp.include? 'dyld-1015.7'
|
|
||||||
$MK_CPPFLAGS << ' -DHAVE_BUGGY_APPLE_LINKER'
|
|
||||||
end
|
|
||||||
|
|
||||||
if %w[Linux Darwin FreeBSD NetBSD OpenBSD Haiku].include? $UNAME_S
|
|
||||||
$MK_CFLAGS << ' -pthread'
|
|
||||||
$MK_CXXFLAGS << ' -pthread'
|
|
||||||
end
|
|
||||||
|
|
||||||
unless $_WIN32
|
|
||||||
$DSO_EXT = '.so'
|
|
||||||
else
|
|
||||||
$DSO_EXT = '.dll'
|
|
||||||
end
|
|
||||||
|
|
||||||
unless ENV['RISCV']
|
|
||||||
if %w[x86_64 i686 amd64].include? $UNAME_M
|
|
||||||
$HOST_CXXFLAGS ||= ''
|
|
||||||
|
|
||||||
$MK_CFLAGS << ' -march=native -mtune=native'
|
|
||||||
$HOST_CXXFLAGS << ' -march=native -mtune=native'
|
|
||||||
end
|
|
||||||
else
|
|
||||||
$MK_CFLAGS << ' -march=rv64gcv -mabi=lp64d'
|
|
||||||
$MK_CXXFLAGS << ' -march=rv64gcv -mabi=lp64d'
|
|
||||||
end
|
|
||||||
|
|
||||||
unless ENV['GGML_NO_ACCELERATE']
|
|
||||||
if $UNAME_S == 'Darwin'
|
|
||||||
$MK_CPPFLAGS << ' -DGGML_USE_ACCELERATE -DGGML_USE_BLAS -DGGML_BLAS_USE_ACCELERATE'
|
|
||||||
$MK_CPPFLAGS << ' -DACCELERATE_NEW_LAPACK'
|
|
||||||
$MK_CPPFLAGS << ' -DACCELERATE_LAPACK_ILP64'
|
|
||||||
$MK_LDFLAGS << ' -framework Accelerate'
|
|
||||||
$OBJ_GGML << 'ggml/src/ggml-blas/ggml-blas.o'
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
if ENV['GGML_OPENBLAS']
|
|
||||||
$MK_CPPFLAGS << " -DGGML_USE_BLAS #{`pkg-config --cflags-only-I openblas`.chomp}"
|
|
||||||
$MK_CFLAGS << " #{`pkg-config --cflags-only-other openblas)`.chomp}"
|
|
||||||
$MK_LDFLAGS << " #{`pkg-config --libs openblas`}"
|
|
||||||
$OBJ_GGML << 'ggml/src/ggml-blas/ggml-blas.o'
|
|
||||||
end
|
|
||||||
|
|
||||||
if ENV['GGML_OPENBLAS64']
|
|
||||||
$MK_CPPFLAGS << " -DGGML_USE_BLAS #{`pkg-config --cflags-only-I openblas64`.chomp}"
|
|
||||||
$MK_CFLAGS << " #{`pkg-config --cflags-only-other openblas64)`.chomp}"
|
|
||||||
$MK_LDFLAGS << " #{`pkg-config --libs openblas64`}"
|
|
||||||
$OBJ_GGML << 'ggml/src/ggml-blas/ggml-blas.o'
|
|
||||||
end
|
|
||||||
|
|
||||||
if $GGML_METAL
|
|
||||||
$MK_CPPFLAGS << ' -DGGML_USE_METAL'
|
|
||||||
$MK_LDFLAGS << ' -framework Foundation -framework Metal -framework MetalKit'
|
|
||||||
$OBJ_GGML << 'ggml/src/ggml-metal/ggml-metal.o'
|
|
||||||
|
|
||||||
if ENV['GGML_METAL_NDEBUG']
|
|
||||||
$MK_CPPFLAGS << ' -DGGML_METAL_NDEBUG'
|
|
||||||
end
|
|
||||||
|
|
||||||
if $GGML_METAL_EMBED_LIBRARY
|
|
||||||
$MK_CPPFLAGS << ' -DGGML_METAL_EMBED_LIBRARY'
|
|
||||||
$OBJ_GGML << 'ggml/src/ggml-metal/ggml-metal-embed.o'
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
$OBJ_GGML <<
|
|
||||||
'ggml/src/ggml.o' <<
|
|
||||||
'ggml/src/ggml-alloc.o' <<
|
|
||||||
'ggml/src/ggml-backend.o' <<
|
|
||||||
'ggml/src/ggml-backend-reg.o' <<
|
|
||||||
'ggml/src/ggml-opt.o' <<
|
|
||||||
'ggml/src/ggml-quants.o' <<
|
|
||||||
'ggml/src/ggml-threading.o' <<
|
|
||||||
'ggml/src/ggml-cpu/ggml-cpu.o' <<
|
|
||||||
'ggml/src/ggml-cpu/ggml-cpu-cpp.o' <<
|
|
||||||
'ggml/src/ggml-cpu/ggml-cpu-aarch64.o' <<
|
|
||||||
'ggml/src/ggml-cpu/ggml-cpu-hbm.o' <<
|
|
||||||
'ggml/src/ggml-cpu/ggml-cpu-quants.o' <<
|
|
||||||
'ggml/src/ggml-cpu/ggml-cpu-traits.o' <<
|
|
||||||
'ggml/src/ggml-cpu/unary-ops.o' <<
|
|
||||||
'ggml/src/ggml-cpu/binary-ops.o'
|
|
||||||
|
|
||||||
$OBJ_WHISPER <<
|
|
||||||
'src/whisper.o' <<
|
|
||||||
'examples/common.o' <<
|
|
||||||
'examples/common-whisper.o'
|
|
||||||
|
|
||||||
$objs = $OBJ_GGML + $OBJ_WHISPER + $OBJ_COMMON + $OBJ_SDL
|
|
||||||
$objs <<
|
|
||||||
"ruby_whisper.o" <<
|
|
||||||
"ruby_whisper_context.o" <<
|
|
||||||
"ruby_whisper_transcribe.o" <<
|
|
||||||
"ruby_whisper_params.o" <<
|
|
||||||
"ruby_whisper_error.o" <<
|
|
||||||
"ruby_whisper_segment.o" <<
|
|
||||||
"ruby_whisper_model.o"
|
|
||||||
|
|
||||||
$CPPFLAGS = "#{$MK_CPPFLAGS} #{$CPPFLAGS}"
|
|
||||||
$CFLAGS = "#{$CPPFLAGS} #{$MK_CFLAGS} #{$GF_CFLAGS} #{$CFLAGS}"
|
|
||||||
$BASE_CXXFLAGS = "#{$MK_CXXFLAGS} #{$CXXFLAGS}"
|
|
||||||
$CXXFLAGS = "#{$BASE_CXXFLAGS} #{$HOST_CXXFLAGS} #{$GF_CXXFLAGS} #{$CPPFLAGS}"
|
|
||||||
$NVCCFLAGS = "#{$MK_NVCCFLAGS} #{$NVCCFLAGS}"
|
|
||||||
$LDFLAGS = "#{$MK_LDFLAGS} #{$LDFLAGS}"
|
|
||||||
|
|
||||||
create_makefile('whisper')
|
|
||||||
|
|
||||||
File.open 'Makefile', 'a' do |file|
|
|
||||||
file.puts 'include scripts/get-flags.mk'
|
|
||||||
file.puts 'include cpu.mk'
|
|
||||||
|
|
||||||
if $GGML_METAL
|
|
||||||
file.puts 'include metal.mk'
|
|
||||||
|
|
||||||
if $GGML_METAL_EMBED_LIBRARY
|
|
||||||
file.puts 'include metal-embed.mk'
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
@ -1,17 +0,0 @@
|
|||||||
ggml/src/ggml-metal/ggml-metal-embed.o: \
|
|
||||||
ggml/src/ggml-metal/ggml-metal.metal \
|
|
||||||
ggml/src/ggml-metal/ggml-metal-impl.h \
|
|
||||||
ggml/src/ggml-common.h
|
|
||||||
@echo "Embedding Metal library"
|
|
||||||
@sed -e '/__embed_ggml-common.h__/r ggml/src/ggml-common.h' -e '/__embed_ggml-common.h__/d' < ggml/src/ggml-metal/ggml-metal.metal > ggml/src/ggml-metal/ggml-metal-embed.metal.tmp
|
|
||||||
@sed -e '/#include "ggml-metal-impl.h"/r ggml/src/ggml-metal/ggml-metal-impl.h' -e '/#include "ggml-metal-impl.h"/d' < ggml/src/ggml-metal/ggml-metal-embed.metal.tmp > ggml/src/ggml-metal/ggml-metal-embed.metal
|
|
||||||
$(eval TEMP_ASSEMBLY=$(shell mktemp -d))
|
|
||||||
@echo ".section __DATA, __ggml_metallib" > $(TEMP_ASSEMBLY)/ggml-metal-embed.s
|
|
||||||
@echo ".globl _ggml_metallib_start" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
|
|
||||||
@echo "_ggml_metallib_start:" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
|
|
||||||
@echo ".incbin \"ggml/src/ggml-metal/ggml-metal-embed.metal\"" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
|
|
||||||
@echo ".globl _ggml_metallib_end" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
|
|
||||||
@echo "_ggml_metallib_end:" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
|
|
||||||
$(CC) $(CFLAGS) -c $(TEMP_ASSEMBLY)/ggml-metal-embed.s -o $@
|
|
||||||
@rm -f ${TEMP_ASSEMBLY}/ggml-metal-embed.s
|
|
||||||
@rmdir ${TEMP_ASSEMBLY}
|
|
@ -1,6 +0,0 @@
|
|||||||
ggml/src/ggml-metal/ggml-metal.o: \
|
|
||||||
ggml/src/ggml-metal/ggml-metal.m \
|
|
||||||
ggml/src/ggml-metal/ggml-metal-impl.h \
|
|
||||||
ggml/include/ggml-metal.h \
|
|
||||||
ggml/include/ggml.h
|
|
||||||
$(CC) $(CFLAGS) -c $< -o $@
|
|
219
bindings/ruby/ext/options.rb
Normal file
@ -0,0 +1,219 @@
|
|||||||
|
class Options
|
||||||
|
def initialize
|
||||||
|
@options = {}
|
||||||
|
@pending_options = []
|
||||||
|
@ignored_options = []
|
||||||
|
|
||||||
|
configure
|
||||||
|
end
|
||||||
|
|
||||||
|
def help
|
||||||
|
@options
|
||||||
|
.collect_concat {|name, (type, value)|
|
||||||
|
option = option_name(name)
|
||||||
|
if type == :bool
|
||||||
|
["--enable-#{option}", "--disable-#{option}"]
|
||||||
|
else
|
||||||
|
"--#{option}=#{type.upcase}"
|
||||||
|
end
|
||||||
|
}
|
||||||
|
.join($/)
|
||||||
|
end
|
||||||
|
|
||||||
|
def to_s
|
||||||
|
@options
|
||||||
|
.reject {|name, (type, value)| value.nil?}
|
||||||
|
.collect {|name, (type, value)| "-D #{name}=#{value == true ? "ON" : value == false ? "OFF" : value.shellescape}"}
|
||||||
|
.join(" ")
|
||||||
|
end
|
||||||
|
|
||||||
|
def cmake_options
|
||||||
|
return @cmake_options if @cmake_options
|
||||||
|
|
||||||
|
output = nil
|
||||||
|
Dir.chdir __dir__ do
|
||||||
|
output = `cmake -S sources -B build -L`
|
||||||
|
end
|
||||||
|
started = false
|
||||||
|
@cmake_options = output.lines.filter_map {|line|
|
||||||
|
if line.chomp == "-- Cache values"
|
||||||
|
started = true
|
||||||
|
next
|
||||||
|
end
|
||||||
|
next unless started
|
||||||
|
option, value = line.chomp.split("=", 2)
|
||||||
|
name, type = option.split(":", 2)
|
||||||
|
[name, type, value]
|
||||||
|
}
|
||||||
|
end
|
||||||
|
|
||||||
|
def missing_options
|
||||||
|
cmake_options.collect {|name, type, value| name} -
|
||||||
|
@options.keys - @pending_options - @ignored_options
|
||||||
|
end
|
||||||
|
|
||||||
|
def extra_options
|
||||||
|
@options.keys + @pending_options - @ignored_options -
|
||||||
|
cmake_options.collect {|name, type, value| name}
|
||||||
|
end
|
||||||
|
|
||||||
|
private
|
||||||
|
|
||||||
|
def configure
|
||||||
|
filepath "ACCELERATE_FRAMEWORK"
|
||||||
|
ignored "BUILD_SHARED_LIBS"
|
||||||
|
ignored "BUILD_TESTING"
|
||||||
|
ignored "CMAKE_BUILD_TYPE"
|
||||||
|
ignored "CMAKE_INSTALL_PREFIX"
|
||||||
|
string "CMAKE_OSX_ARCHITECTURES"
|
||||||
|
ignored "CMAKE_OSX_DEPLOYMENT_TARGET"
|
||||||
|
string "CMAKE_OSX_SYSROOT"
|
||||||
|
filepath "FOUNDATION_LIBRARY"
|
||||||
|
bool "GGML_ACCELERATE"
|
||||||
|
bool "GGML_ALL_WARNINGS_3RD_PARTY"
|
||||||
|
bool "GGML_AMX_BF16"
|
||||||
|
bool "GGML_AMX_INT8"
|
||||||
|
bool "GGML_AMX_TILE"
|
||||||
|
bool "GGML_AVX"
|
||||||
|
bool "GGML_AVX2"
|
||||||
|
bool "GGML_AVX512"
|
||||||
|
bool "GGML_AVX512_BF16"
|
||||||
|
bool "GGML_AVX512_VBMI"
|
||||||
|
bool "GGML_AVX512_VNNI"
|
||||||
|
bool "GGML_AVX_VNNI"
|
||||||
|
ignored "GGML_BACKEND_DL"
|
||||||
|
ignored "GGML_BIN_INSTALL_DIR"
|
||||||
|
bool "GGML_BLAS"
|
||||||
|
string "GGML_BLAS_VENDOR"
|
||||||
|
bool "GGML_BMI2"
|
||||||
|
ignored "GGML_BUILD_EXAMPLES"
|
||||||
|
ignored "GGML_BUILD_TESTS"
|
||||||
|
filepath "GGML_CCACHE_FOUND"
|
||||||
|
bool "GGML_CPU"
|
||||||
|
bool "GGML_CPU_AARCH64"
|
||||||
|
ignored "GGML_CPU_ALL_VARIANTS"
|
||||||
|
string "GGML_CPU_ARM_ARCH"
|
||||||
|
bool "GGML_CPU_HBM"
|
||||||
|
bool "GGML_CPU_KLEIDIAI"
|
||||||
|
string "GGML_CPU_POWERPC_CPUTYPE"
|
||||||
|
bool "GGML_CUDA"
|
||||||
|
string "GGML_CUDA_COMPRESSION_MODE"
|
||||||
|
bool "GGML_CUDA_F16"
|
||||||
|
bool "GGML_CUDA_FA"
|
||||||
|
bool "GGML_CUDA_FA_ALL_QUANTS"
|
||||||
|
bool "GGML_CUDA_FORCE_CUBLAS"
|
||||||
|
bool "GGML_CUDA_FORCE_MMQ"
|
||||||
|
ignored "GGML_CUDA_GRAPHS"
|
||||||
|
bool "GGML_CUDA_NO_PEER_COPY"
|
||||||
|
bool "GGML_CUDA_NO_VMM"
|
||||||
|
string "GGML_CUDA_PEER_MAX_BATCH_SIZE"
|
||||||
|
bool "GGML_F16C"
|
||||||
|
bool "GGML_FMA"
|
||||||
|
bool "GGML_GPROF"
|
||||||
|
bool "GGML_HIP"
|
||||||
|
bool "GGML_HIP_GRAPHS"
|
||||||
|
bool "GGML_HIP_NO_VMM"
|
||||||
|
bool "GGML_HIP_ROCWMMA_FATTN"
|
||||||
|
ignored "GGML_INCLUDE_INSTALL_DIR"
|
||||||
|
bool "GGML_KOMPUTE"
|
||||||
|
bool "GGML_LASX"
|
||||||
|
ignored "GGML_LIB_INSTALL_DIR"
|
||||||
|
ignored "GGML_LLAMAFILE"
|
||||||
|
bool "GGML_LSX"
|
||||||
|
bool "GGML_LTO"
|
||||||
|
bool "GGML_METAL"
|
||||||
|
bool "GGML_METAL_EMBED_LIBRARY"
|
||||||
|
string "GGML_METAL_MACOSX_VERSION_MIN"
|
||||||
|
bool "GGML_METAL_NDEBUG"
|
||||||
|
bool "GGML_METAL_SHADER_DEBUG"
|
||||||
|
string "GGML_METAL_STD"
|
||||||
|
bool "GGML_METAL_USE_BF16"
|
||||||
|
bool "GGML_MUSA"
|
||||||
|
bool "GGML_NATIVE"
|
||||||
|
bool "GGML_OPENCL"
|
||||||
|
bool "GGML_OPENCL_EMBED_KERNELS"
|
||||||
|
bool "GGML_OPENCL_PROFILING"
|
||||||
|
string "GGML_OPENCL_TARGET_VERSION"
|
||||||
|
bool "GGML_OPENCL_USE_ADRENO_KERNELS"
|
||||||
|
bool "GGML_OPENMP"
|
||||||
|
bool "GGML_RPC"
|
||||||
|
bool "GGML_RVV"
|
||||||
|
bool "GGML_RV_ZFH"
|
||||||
|
pending "GGML_SCCACHE_FOUND"
|
||||||
|
string "GGML_SCHED_MAX_COPIES"
|
||||||
|
bool "GGML_SSE42"
|
||||||
|
ignored "GGML_STATIC"
|
||||||
|
bool "GGML_SYCL"
|
||||||
|
string "GGML_SYCL_DEVICE_ARCH"
|
||||||
|
bool "GGML_SYCL_F16"
|
||||||
|
bool "GGML_SYCL_GRAPH"
|
||||||
|
string "GGML_SYCL_TARGET"
|
||||||
|
bool "GGML_VULKAN"
|
||||||
|
bool "GGML_VULKAN_CHECK_RESULTS"
|
||||||
|
bool "GGML_VULKAN_DEBUG"
|
||||||
|
bool "GGML_VULKAN_MEMORY_DEBUG"
|
||||||
|
bool "GGML_VULKAN_PERF"
|
||||||
|
ignored "GGML_VULKAN_RUN_TESTS"
|
||||||
|
filepath "GGML_VULKAN_SHADERS_GEN_TOOLCHAIN"
|
||||||
|
bool "GGML_VULKAN_SHADER_DEBUG_INFO"
|
||||||
|
pending "GGML_VULKAN_VALIDATE"
|
||||||
|
bool "GGML_VXE"
|
||||||
|
filepath "GIT_EXE"
|
||||||
|
filepath "MATH_LIBRARY"
|
||||||
|
filepath "METALKIT_FRAMEWORK"
|
||||||
|
filepath "METAL_FRAMEWORK"
|
||||||
|
bool "WHISPER_ALL_WARNINGS"
|
||||||
|
bool "WHISPER_ALL_WARNINGS_3RD_PARTY"
|
||||||
|
ignored "WHISPER_BIN_INSTALL_DIR"
|
||||||
|
ignored "WHISPER_BUILD_EXAMPLES"
|
||||||
|
ignored "WHISPER_BUILD_SERVER"
|
||||||
|
ignored "WHISPER_BUILD_TESTS"
|
||||||
|
bool "WHISPER_CCACHE"
|
||||||
|
bool "WHISPER_COREML"
|
||||||
|
bool "WHISPER_COREML_ALLOW_FALLBACK"
|
||||||
|
ignored "WHISPER_CURL"
|
||||||
|
bool "WHISPER_FATAL_WARNINGS"
|
||||||
|
ignored "WHISPER_FFMPEG"
|
||||||
|
ignored "WHISPER_INCLUDE_INSTALL_DIR"
|
||||||
|
ignored "WHISPER_LIB_INSTALL_DIR"
|
||||||
|
bool "WHISPER_OPENVINO"
|
||||||
|
bool "WHISPER_SANITIZE_ADDRESS"
|
||||||
|
bool "WHISPER_SANITIZE_THREAD"
|
||||||
|
bool "WHISPER_SANITIZE_UNDEFINED"
|
||||||
|
ignored "WHISPER_SDL2"
|
||||||
|
pending "WHISPER_USE_SYSTEM_GGML"
|
||||||
|
end
|
||||||
|
|
||||||
|
def option_name(name)
|
||||||
|
name.downcase.gsub("_", "-")
|
||||||
|
end
|
||||||
|
|
||||||
|
def bool(name)
|
||||||
|
option = option_name(name)
|
||||||
|
value = enable_config(option)
|
||||||
|
@options[name] = [:bool, value]
|
||||||
|
end
|
||||||
|
|
||||||
|
def string(name, type=:string)
|
||||||
|
option = "--#{option_name(name)}"
|
||||||
|
value = arg_config(option)
|
||||||
|
raise "String expected for #{option}" if value == true || value&.empty?
|
||||||
|
@options[name] = [type, value]
|
||||||
|
end
|
||||||
|
|
||||||
|
def path(name)
|
||||||
|
string(name, :path)
|
||||||
|
end
|
||||||
|
|
||||||
|
def filepath(name)
|
||||||
|
string(name, :filepath)
|
||||||
|
end
|
||||||
|
|
||||||
|
def pending(name)
|
||||||
|
@pending_options << name
|
||||||
|
end
|
||||||
|
|
||||||
|
def ignored(name)
|
||||||
|
@ignored_options << name
|
||||||
|
end
|
||||||
|
end
|
@ -19,6 +19,7 @@ typedef struct {
|
|||||||
bool diarize;
|
bool diarize;
|
||||||
ruby_whisper_callback_container *new_segment_callback_container;
|
ruby_whisper_callback_container *new_segment_callback_container;
|
||||||
ruby_whisper_callback_container *progress_callback_container;
|
ruby_whisper_callback_container *progress_callback_container;
|
||||||
|
ruby_whisper_callback_container *encoder_begin_callback_container;
|
||||||
ruby_whisper_callback_container *abort_callback_container;
|
ruby_whisper_callback_container *abort_callback_container;
|
||||||
} ruby_whisper_params;
|
} ruby_whisper_params;
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@
|
|||||||
rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \
|
rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \
|
||||||
rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1);
|
rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1);
|
||||||
|
|
||||||
#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 30
|
#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 32
|
||||||
|
|
||||||
extern VALUE cParams;
|
extern VALUE cParams;
|
||||||
|
|
||||||
@ -63,6 +63,8 @@ static ID id_new_segment_callback;
|
|||||||
static ID id_new_segment_callback_user_data;
|
static ID id_new_segment_callback_user_data;
|
||||||
static ID id_progress_callback;
|
static ID id_progress_callback;
|
||||||
static ID id_progress_callback_user_data;
|
static ID id_progress_callback_user_data;
|
||||||
|
static ID id_encoder_begin_callback;
|
||||||
|
static ID id_encoder_begin_callback_user_data;
|
||||||
static ID id_abort_callback;
|
static ID id_abort_callback;
|
||||||
static ID id_abort_callback_user_data;
|
static ID id_abort_callback_user_data;
|
||||||
|
|
||||||
@ -126,6 +128,33 @@ static void progress_callback(struct whisper_context *ctx, struct whisper_state
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_state *state, void *user_data) {
|
||||||
|
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
|
||||||
|
bool is_aborted = false;
|
||||||
|
VALUE result;
|
||||||
|
|
||||||
|
// Currently, state is not supported because
|
||||||
|
// supporting it would require resolving GC-related problems.
|
||||||
|
if (!NIL_P(container->callback)) {
|
||||||
|
result = rb_funcall(container->callback, id_call, 3, *container->context, Qnil, container->user_data);
|
||||||
|
if (result == Qfalse) {
|
||||||
|
is_aborted = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const long callbacks_len = RARRAY_LEN(container->callbacks);
|
||||||
|
if (0 == callbacks_len) {
|
||||||
|
return !is_aborted;
|
||||||
|
}
|
||||||
|
for (int j = 0; j < callbacks_len; j++) {
|
||||||
|
VALUE cb = rb_ary_entry(container->callbacks, j);
|
||||||
|
result = rb_funcall(cb, id_call, 0);
|
||||||
|
if (result == Qfalse) {
|
||||||
|
is_aborted = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return !is_aborted;
|
||||||
|
}
|
||||||
|
|
||||||
static bool abort_callback(void * user_data) {
|
static bool abort_callback(void * user_data) {
|
||||||
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
|
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
|
||||||
if (!NIL_P(container->callback)) {
|
if (!NIL_P(container->callback)) {
|
||||||
@ -161,6 +190,12 @@ void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
|
|||||||
rwp->params.progress_callback_user_data = rwp->progress_callback_container;
|
rwp->params.progress_callback_user_data = rwp->progress_callback_container;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!NIL_P(rwp->encoder_begin_callback_container->callback) || 0 != RARRAY_LEN(rwp->encoder_begin_callback_container->callbacks)) {
|
||||||
|
rwp->encoder_begin_callback_container->context = context;
|
||||||
|
rwp->params.encoder_begin_callback = encoder_begin_callback;
|
||||||
|
rwp->params.encoder_begin_callback_user_data = rwp->encoder_begin_callback_container;
|
||||||
|
}
|
||||||
|
|
||||||
if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
|
if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
|
||||||
rwp->abort_callback_container->context = context;
|
rwp->abort_callback_container->context = context;
|
||||||
rwp->params.abort_callback = abort_callback;
|
rwp->params.abort_callback = abort_callback;
|
||||||
@ -173,6 +208,7 @@ rb_whisper_params_mark(ruby_whisper_params *rwp)
|
|||||||
{
|
{
|
||||||
rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
|
rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
|
||||||
rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
|
rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
|
||||||
|
rb_whisper_callbcack_container_mark(rwp->encoder_begin_callback_container);
|
||||||
rb_whisper_callbcack_container_mark(rwp->abort_callback_container);
|
rb_whisper_callbcack_container_mark(rwp->abort_callback_container);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -198,6 +234,7 @@ ruby_whisper_params_allocate(VALUE klass)
|
|||||||
rwp->diarize = false;
|
rwp->diarize = false;
|
||||||
rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
|
rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
|
||||||
rwp->progress_callback_container = rb_whisper_callback_container_allocate();
|
rwp->progress_callback_container = rb_whisper_callback_container_allocate();
|
||||||
|
rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate();
|
||||||
rwp->abort_callback_container = rb_whisper_callback_container_allocate();
|
rwp->abort_callback_container = rb_whisper_callback_container_allocate();
|
||||||
return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
|
return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
|
||||||
}
|
}
|
||||||
@ -849,6 +886,57 @@ ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value)
|
|||||||
rwp->progress_callback_container->user_data = value;
|
rwp->progress_callback_container->user_data = value;
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static VALUE
|
||||||
|
ruby_whisper_params_get_encoder_begin_callback(VALUE self)
|
||||||
|
{
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
return rwp->encoder_begin_callback_container->callback;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Sets encoder begin callback, called when the encoder starts.
|
||||||
|
*
|
||||||
|
* params.encoder_begin_callback = ->(context, _, user_data) {
|
||||||
|
* # ...
|
||||||
|
* }
|
||||||
|
*
|
||||||
|
* call-seq:
|
||||||
|
* encoder_begin_callback = callback -> callback
|
||||||
|
*/
|
||||||
|
static VALUE
|
||||||
|
ruby_whisper_params_set_encoder_begin_callback(VALUE self, VALUE value)
|
||||||
|
{
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
rwp->encoder_begin_callback_container->callback = value;
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
static VALUE
|
||||||
|
ruby_whisper_params_get_encoder_begin_callback_user_data(VALUE self)
|
||||||
|
{
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
return rwp->encoder_begin_callback_container->user_data;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Sets user data passed to the last argument of encoder begin callback.
|
||||||
|
*
|
||||||
|
* call-seq:
|
||||||
|
* encoder_begin_callback_user_data = user_data -> user_data
|
||||||
|
*/
|
||||||
|
static VALUE
|
||||||
|
ruby_whisper_params_set_encoder_begin_callback_user_data(VALUE self, VALUE value)
|
||||||
|
{
|
||||||
|
ruby_whisper_params *rwp;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
rwp->encoder_begin_callback_container->user_data = value;
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
static VALUE
|
static VALUE
|
||||||
ruby_whisper_params_get_abort_callback(VALUE self)
|
ruby_whisper_params_get_abort_callback(VALUE self)
|
||||||
{
|
{
|
||||||
@ -918,7 +1006,7 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
|
|||||||
return self;
|
return self;
|
||||||
}
|
}
|
||||||
|
|
||||||
rb_get_kwargs(kw_hash, ¶m_names, 0, RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT, &values);
|
rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT, values);
|
||||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||||
|
|
||||||
for (i = 0; i < RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT; i++) {
|
for (i = 0; i < RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT; i++) {
|
||||||
@ -958,6 +1046,8 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
|
|||||||
SET_PARAM_IF_SAME(new_segment_callback_user_data)
|
SET_PARAM_IF_SAME(new_segment_callback_user_data)
|
||||||
SET_PARAM_IF_SAME(progress_callback)
|
SET_PARAM_IF_SAME(progress_callback)
|
||||||
SET_PARAM_IF_SAME(progress_callback_user_data)
|
SET_PARAM_IF_SAME(progress_callback_user_data)
|
||||||
|
SET_PARAM_IF_SAME(encoder_begin_callback)
|
||||||
|
SET_PARAM_IF_SAME(encoder_begin_callback_user_data)
|
||||||
SET_PARAM_IF_SAME(abort_callback)
|
SET_PARAM_IF_SAME(abort_callback)
|
||||||
SET_PARAM_IF_SAME(abort_callback_user_data)
|
SET_PARAM_IF_SAME(abort_callback_user_data)
|
||||||
}
|
}
|
||||||
@ -1008,6 +1098,26 @@ ruby_whisper_params_on_progress(VALUE self)
|
|||||||
return Qnil;
|
return Qnil;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Hook called when the encoder starts.
|
||||||
|
*
|
||||||
|
* whisper.on_encoder_begin do
|
||||||
|
* # ...
|
||||||
|
* end
|
||||||
|
*
|
||||||
|
* call-seq:
|
||||||
|
* on_encoder_begin { ... }
|
||||||
|
*/
|
||||||
|
static VALUE
|
||||||
|
ruby_whisper_params_on_encoder_begin(VALUE self)
|
||||||
|
{
|
||||||
|
ruby_whisper_params *rws;
|
||||||
|
Data_Get_Struct(self, ruby_whisper_params, rws);
|
||||||
|
const VALUE blk = rb_block_proc();
|
||||||
|
rb_ary_push(rws->encoder_begin_callback_container->callbacks, blk);
|
||||||
|
return Qnil;
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Call block to determine whether abort or not. Return +true+ when you want to abort.
|
* Call block to determine whether abort or not. Return +true+ when you want to abort.
|
||||||
*
|
*
|
||||||
@ -1068,10 +1178,13 @@ init_ruby_whisper_params(VALUE *mWhisper)
|
|||||||
DEFINE_PARAM(new_segment_callback_user_data, 25)
|
DEFINE_PARAM(new_segment_callback_user_data, 25)
|
||||||
DEFINE_PARAM(progress_callback, 26)
|
DEFINE_PARAM(progress_callback, 26)
|
||||||
DEFINE_PARAM(progress_callback_user_data, 27)
|
DEFINE_PARAM(progress_callback_user_data, 27)
|
||||||
DEFINE_PARAM(abort_callback, 28)
|
DEFINE_PARAM(encoder_begin_callback, 28)
|
||||||
DEFINE_PARAM(abort_callback_user_data, 29)
|
DEFINE_PARAM(encoder_begin_callback_user_data, 29)
|
||||||
|
DEFINE_PARAM(abort_callback, 30)
|
||||||
|
DEFINE_PARAM(abort_callback_user_data, 31)
|
||||||
|
|
||||||
rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
|
rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
|
||||||
rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
|
rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
|
||||||
|
rb_define_method(cParams, "on_encoder_begin", ruby_whisper_params_on_encoder_begin, 0);
|
||||||
rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0);
|
rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0);
|
||||||
}
|
}
|
||||||
|
@ -50,15 +50,16 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
|
|||||||
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
|
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
|
||||||
return self;
|
return self;
|
||||||
}
|
}
|
||||||
{
|
// Commented out because it is work in progress
|
||||||
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
// {
|
||||||
|
// static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
||||||
|
|
||||||
rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
|
// rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
|
||||||
bool is_aborted = *(bool*)user_data;
|
// bool is_aborted = *(bool*)user_data;
|
||||||
return !is_aborted;
|
// return !is_aborted;
|
||||||
};
|
// };
|
||||||
rwp->params.encoder_begin_callback_user_data = &is_aborted;
|
// rwp->params.encoder_begin_callback_user_data = &is_aborted;
|
||||||
}
|
// }
|
||||||
|
|
||||||
register_callbacks(rwp, &self);
|
register_callbacks(rwp, &self);
|
||||||
|
|
||||||
|
8
bindings/ruby/ext/sources/CMakeGraphVizOptions.cmake
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
set(GRAPHVIZ_EXECUTABLES FALSE)
|
||||||
|
set(GRAPHVIZ_STATIC_LIBS TRUE)
|
||||||
|
set(GRAPHVIZ_SHARED_LIBS FALSE)
|
||||||
|
set(GRAPHVIZ_MODULE_LIBS FALSE)
|
||||||
|
set(GRAPHVIZ_INTERFACE_LIBS FALSE)
|
||||||
|
set(GRAPHVIZ_OBJECT_LIBS FALSE)
|
||||||
|
set(GRAPHVIZ_UNKNOWN_LIBS FALSE)
|
||||||
|
set(GRAPHVIZ_GENERATE_DEPENDERS FALSE)
|
@ -1,6 +1,34 @@
|
|||||||
require "yaml"
|
ignored_dirs = %w[
|
||||||
|
.devops
|
||||||
|
examples/wchess/wchess.wasm
|
||||||
|
examples/whisper.android
|
||||||
|
examples/whisper.android.java
|
||||||
|
examples/whisper.objc
|
||||||
|
examples/whisper.swiftui
|
||||||
|
grammars
|
||||||
|
models
|
||||||
|
samples
|
||||||
|
scripts
|
||||||
|
]
|
||||||
|
ignored_files = %w[
|
||||||
|
AUTHORS
|
||||||
|
Makefile
|
||||||
|
README.md
|
||||||
|
README_sycl.md
|
||||||
|
.gitignore
|
||||||
|
.gitmodules
|
||||||
|
whisper.nvim
|
||||||
|
twitch.sh
|
||||||
|
yt-wsp.sh
|
||||||
|
]
|
||||||
|
|
||||||
sources = `git ls-files -z ../..`.split("\x0")
|
EXTSOURCES =
|
||||||
paths = YAML.load_file("../../.github/workflows/bindings-ruby.yml")[true]["push"]["paths"]
|
`git ls-files -z ../..`.split("\x0")
|
||||||
paths.delete "bindings/ruby/**"
|
.select {|file|
|
||||||
EXTSOURCES = (Dir.glob(paths, base: "../..").collect {|path| "../../#{path}"} << "../../LICENSE") & sources
|
basename = File.basename(file)
|
||||||
|
|
||||||
|
ignored_dirs.all? {|dir| !file.start_with?("../../#{dir}")} &&
|
||||||
|
!ignored_files.include?(basename) &&
|
||||||
|
(file.start_with?("../..") || file.start_with?("../javascript")) &&
|
||||||
|
(!file.start_with?("../../.github/") || basename == "bindings-ruby.yml")
|
||||||
|
}
|
||||||
|
@ -34,7 +34,7 @@ module Whisper
|
|||||||
when /darwin/
|
when /darwin/
|
||||||
Pathname(Dir.home)/"Library/Caches"
|
Pathname(Dir.home)/"Library/Caches"
|
||||||
else
|
else
|
||||||
ENV.key?("XDG_CACHE_HOME") ? ENV["XDG_CACHE_HOME"] : Pathname(Dir.home)/".cache"
|
ENV.key?("XDG_CACHE_HOME") ? Pathname(ENV["XDG_CACHE_HOME"]) : Pathname(Dir.home)/".cache"
|
||||||
end
|
end
|
||||||
base/"whisper.cpp"
|
base/"whisper.cpp"
|
||||||
end
|
end
|
||||||
@ -53,8 +53,10 @@ module Whisper
|
|||||||
http.request request do |response|
|
http.request request do |response|
|
||||||
case response
|
case response
|
||||||
when Net::HTTPNotModified
|
when Net::HTTPNotModified
|
||||||
# noop
|
# noop
|
||||||
when Net::HTTPOK
|
when Net::HTTPOK
|
||||||
|
return if !response.key?("last-modified") && cache_path.exist?
|
||||||
|
|
||||||
download response
|
download response
|
||||||
when Net::HTTPRedirection
|
when Net::HTTPRedirection
|
||||||
request URI(response["location"]), headers
|
request URI(response["location"]), headers
|
||||||
@ -68,7 +70,7 @@ module Whisper
|
|||||||
rescue => err
|
rescue => err
|
||||||
if cache_path.exist?
|
if cache_path.exist?
|
||||||
warn err
|
warn err
|
||||||
# Use cache file
|
# Use cache file
|
||||||
else
|
else
|
||||||
raise
|
raise
|
||||||
end
|
end
|
||||||
|
@ -7,6 +7,7 @@ module Whisper
|
|||||||
type log_callback = ^(Integer level, String message, Object user_data) -> void
|
type log_callback = ^(Integer level, String message, Object user_data) -> void
|
||||||
type new_segment_callback = ^(Whisper::Context, void, Integer n_new, Object user_data) -> void
|
type new_segment_callback = ^(Whisper::Context, void, Integer n_new, Object user_data) -> void
|
||||||
type progress_callback = ^(Whisper::Context, void, Integer progress, Object user_data) -> void
|
type progress_callback = ^(Whisper::Context, void, Integer progress, Object user_data) -> void
|
||||||
|
type encoder_begin_callback = ^(Whisper::Context, void, Object user_data) -> void
|
||||||
type abort_callback = ^(Whisper::Context, void, Object user_data) -> boolish
|
type abort_callback = ^(Whisper::Context, void, Object user_data) -> boolish
|
||||||
|
|
||||||
LOG_LEVEL_NONE: Integer
|
LOG_LEVEL_NONE: Integer
|
||||||
@ -23,9 +24,20 @@ module Whisper
|
|||||||
def self.log_set: (log_callback, Object? user_data) -> log_callback
|
def self.log_set: (log_callback, Object? user_data) -> log_callback
|
||||||
|
|
||||||
class Context
|
class Context
|
||||||
def self.new: (string | _ToPath | ::URI::HTTP) -> instance
|
def self.new: (path | ::URI::HTTP) -> instance
|
||||||
|
|
||||||
|
# transcribe a single file
|
||||||
|
# can emit to a block results
|
||||||
|
#
|
||||||
|
# params = Whisper::Params.new
|
||||||
|
# params.duration = 60_000
|
||||||
|
# whisper.transcribe "path/to/audio.wav", params do |text|
|
||||||
|
# puts text
|
||||||
|
# end
|
||||||
|
#
|
||||||
def transcribe: (string, Params) -> self
|
def transcribe: (string, Params) -> self
|
||||||
| (string, Params) { (String) -> void } -> self
|
| (string, Params) { (String) -> void } -> self
|
||||||
|
|
||||||
def model_n_vocab: () -> Integer
|
def model_n_vocab: () -> Integer
|
||||||
def model_n_audio_ctx: () -> Integer
|
def model_n_audio_ctx: () -> Integer
|
||||||
def model_n_audio_state: () -> Integer
|
def model_n_audio_state: () -> Integer
|
||||||
@ -34,19 +46,72 @@ module Whisper
|
|||||||
def model_n_mels: () -> Integer
|
def model_n_mels: () -> Integer
|
||||||
def model_ftype: () -> Integer
|
def model_ftype: () -> Integer
|
||||||
def model_type: () -> String
|
def model_type: () -> String
|
||||||
|
|
||||||
|
# Yields each Whisper::Segment:
|
||||||
|
#
|
||||||
|
# whisper.transcribe("path/to/audio.wav", params)
|
||||||
|
# whisper.each_segment do |segment|
|
||||||
|
# puts segment.text
|
||||||
|
# end
|
||||||
|
#
|
||||||
|
# Returns an Enumerator if no block given:
|
||||||
|
#
|
||||||
|
# whisper.transcribe("path/to/audio.wav", params)
|
||||||
|
# enum = whisper.each_segment
|
||||||
|
# enum.to_a # => [#<Whisper::Segment>, ...]
|
||||||
|
#
|
||||||
def each_segment: { (Segment) -> void } -> void
|
def each_segment: { (Segment) -> void } -> void
|
||||||
| () -> Enumerator[Segment]
|
| () -> Enumerator[Segment]
|
||||||
|
|
||||||
def model: () -> Model
|
def model: () -> Model
|
||||||
def full_get_segment: (Integer nth) -> Segment
|
def full_get_segment: (Integer nth) -> Segment
|
||||||
def full_n_segments: () -> Integer
|
def full_n_segments: () -> Integer
|
||||||
|
|
||||||
|
# Language ID, which can be converted to string by Whisper.lang_str and Whisper.lang_str_full.
|
||||||
|
#
|
||||||
def full_lang_id: () -> Integer
|
def full_lang_id: () -> Integer
|
||||||
|
|
||||||
|
# Start time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
|
||||||
|
#
|
||||||
|
# full_get_segment_t0(3) # => 1668 (16680 ms)
|
||||||
|
#
|
||||||
def full_get_segment_t0: (Integer) -> Integer
|
def full_get_segment_t0: (Integer) -> Integer
|
||||||
|
|
||||||
|
# End time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
|
||||||
|
#
|
||||||
|
# full_get_segment_t1(3) # => 1668 (16680 ms)
|
||||||
|
#
|
||||||
def full_get_segment_t1: (Integer) -> Integer
|
def full_get_segment_t1: (Integer) -> Integer
|
||||||
|
|
||||||
|
# Whether the next segment indexed by +segment_index+ is predicted as a speaker turn.
|
||||||
|
#
|
||||||
|
# full_get_segment_speaker_turn_next(3) # => true
|
||||||
|
#
|
||||||
def full_get_segment_speaker_turn_next: (Integer) -> (true | false)
|
def full_get_segment_speaker_turn_next: (Integer) -> (true | false)
|
||||||
|
|
||||||
|
# Text of a segment indexed by +segment_index+.
|
||||||
|
#
|
||||||
|
# full_get_segment_text(3) # => "ask not what your country can do for you, ..."
|
||||||
|
#
|
||||||
def full_get_segment_text: (Integer) -> String
|
def full_get_segment_text: (Integer) -> String
|
||||||
|
|
||||||
def full_get_segment_no_speech_prob: (Integer) -> Float
|
def full_get_segment_no_speech_prob: (Integer) -> Float
|
||||||
|
|
||||||
|
# Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
||||||
|
# Not thread safe for same context
|
||||||
|
# Uses the specified decoding strategy to obtain the text.
|
||||||
|
#
|
||||||
|
# The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data.
|
||||||
|
#
|
||||||
def full: (Params, Array[Float] samples, ?Integer n_samples) -> self
|
def full: (Params, Array[Float] samples, ?Integer n_samples) -> self
|
||||||
| (Params, _Samples, ?Integer n_samples) -> self
|
| (Params, _Samples, ?Integer n_samples) -> self
|
||||||
|
|
||||||
|
# Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
|
||||||
|
# Result is stored in the default state of the context
|
||||||
|
# Not thread safe if executed in parallel on the same context.
|
||||||
|
# It seems this approach can offer some speedup in some cases.
|
||||||
|
# However, the transcription accuracy can be worse at the beginning and end of each chunk.
|
||||||
|
#
|
||||||
def full_parallel: (Params, Array[Float], ?Integer n_samples) -> self
|
def full_parallel: (Params, Array[Float], ?Integer n_samples) -> self
|
||||||
| (Params, _Samples, ?Integer n_samples) -> self
|
| (Params, _Samples, ?Integer n_samples) -> self
|
||||||
| (Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self
|
| (Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self
|
||||||
@ -82,71 +147,223 @@ module Whisper
|
|||||||
?new_segment_callback_user_data: Object,
|
?new_segment_callback_user_data: Object,
|
||||||
?progress_callback: progress_callback,
|
?progress_callback: progress_callback,
|
||||||
?progress_callback_user_data: Object,
|
?progress_callback_user_data: Object,
|
||||||
|
?encoder_begin_callback: encoder_begin_callback,
|
||||||
|
?encoder_begin_callback_user_data: Object,
|
||||||
?abort_callback: abort_callback,
|
?abort_callback: abort_callback,
|
||||||
?abort_callback_user_data: Object
|
?abort_callback_user_data: Object
|
||||||
) -> instance
|
) -> instance
|
||||||
|
|
||||||
|
# params.language = "auto" | "en", etc...
|
||||||
|
#
|
||||||
def language=: (String) -> String # TODO: Enumerate lang names
|
def language=: (String) -> String # TODO: Enumerate lang names
|
||||||
|
|
||||||
def language: () -> String
|
def language: () -> String
|
||||||
def translate=: (boolish) -> boolish
|
def translate=: (boolish) -> boolish
|
||||||
def translate: () -> (true | false)
|
def translate: () -> (true | false)
|
||||||
def no_context=: (boolish) -> boolish
|
def no_context=: (boolish) -> boolish
|
||||||
|
|
||||||
|
# If true, does not use past transcription (if any) as initial prompt for the decoder.
|
||||||
|
#
|
||||||
def no_context: () -> (true | false)
|
def no_context: () -> (true | false)
|
||||||
|
|
||||||
def single_segment=: (boolish) -> boolish
|
def single_segment=: (boolish) -> boolish
|
||||||
|
|
||||||
|
# If true, forces single segment output (useful for streaming).
|
||||||
|
#
|
||||||
def single_segment: () -> (true | false)
|
def single_segment: () -> (true | false)
|
||||||
|
|
||||||
def print_special=: (boolish) -> boolish
|
def print_special=: (boolish) -> boolish
|
||||||
|
|
||||||
|
# If true, prints special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.).
|
||||||
|
#
|
||||||
def print_special: () -> (true | false)
|
def print_special: () -> (true | false)
|
||||||
|
|
||||||
def print_progress=: (boolish) -> boolish
|
def print_progress=: (boolish) -> boolish
|
||||||
|
|
||||||
|
# If true, prints progress information.
|
||||||
|
#
|
||||||
def print_progress: () -> (true | false)
|
def print_progress: () -> (true | false)
|
||||||
|
|
||||||
def print_realtime=: (boolish) -> boolish
|
def print_realtime=: (boolish) -> boolish
|
||||||
|
|
||||||
|
# If true, prints results from within whisper.cpp. (avoid it, use callback instead)
|
||||||
|
#
|
||||||
def print_realtime: () -> (true | false)
|
def print_realtime: () -> (true | false)
|
||||||
|
|
||||||
|
# If true, prints timestamps for each text segment when printing realtime.
|
||||||
|
#
|
||||||
def print_timestamps=: (boolish) -> boolish
|
def print_timestamps=: (boolish) -> boolish
|
||||||
|
|
||||||
def print_timestamps: () -> (true | false)
|
def print_timestamps: () -> (true | false)
|
||||||
|
|
||||||
def suppress_blank=: (boolish) -> boolish
|
def suppress_blank=: (boolish) -> boolish
|
||||||
|
|
||||||
|
# If true, suppresses blank outputs.
|
||||||
|
#
|
||||||
def suppress_blank: () -> (true | false)
|
def suppress_blank: () -> (true | false)
|
||||||
|
|
||||||
def suppress_nst=: (boolish) -> boolish
|
def suppress_nst=: (boolish) -> boolish
|
||||||
|
|
||||||
|
# If true, suppresses non-speech-tokens.
|
||||||
|
#
|
||||||
def suppress_nst: () -> (true | false)
|
def suppress_nst: () -> (true | false)
|
||||||
|
|
||||||
def token_timestamps=: (boolish) -> boolish
|
def token_timestamps=: (boolish) -> boolish
|
||||||
|
|
||||||
|
# If true, enables token-level timestamps.
|
||||||
|
#
|
||||||
def token_timestamps: () -> (true | false)
|
def token_timestamps: () -> (true | false)
|
||||||
|
|
||||||
def split_on_word=: (boolish) -> boolish
|
def split_on_word=: (boolish) -> boolish
|
||||||
|
|
||||||
|
# If true, split on word rather than on token (when used with max_len).
|
||||||
|
#
|
||||||
def split_on_word: () -> (true | false)
|
def split_on_word: () -> (true | false)
|
||||||
|
|
||||||
def initial_prompt=: (_ToS) -> _ToS
|
def initial_prompt=: (_ToS) -> _ToS
|
||||||
|
|
||||||
|
# Tokens to provide to the whisper decoder as initial prompt
|
||||||
|
# these are prepended to any existing text context from a previous call
|
||||||
|
# use whisper_tokenize() to convert text to tokens.
|
||||||
|
# Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224).
|
||||||
|
#
|
||||||
def initial_prompt: () -> (String | nil)
|
def initial_prompt: () -> (String | nil)
|
||||||
|
|
||||||
def diarize=: (boolish) -> boolish
|
def diarize=: (boolish) -> boolish
|
||||||
|
|
||||||
|
# If true, enables diarization.
|
||||||
|
#
|
||||||
def diarize: () -> (true | false)
|
def diarize: () -> (true | false)
|
||||||
|
|
||||||
def offset=: (Integer) -> Integer
|
def offset=: (Integer) -> Integer
|
||||||
|
|
||||||
|
# Start offset in ms.
|
||||||
|
#
|
||||||
def offset: () -> Integer
|
def offset: () -> Integer
|
||||||
|
|
||||||
def duration=: (Integer) -> Integer
|
def duration=: (Integer) -> Integer
|
||||||
|
|
||||||
|
# Audio duration to process in ms.
|
||||||
|
#
|
||||||
def duration: () -> Integer
|
def duration: () -> Integer
|
||||||
|
|
||||||
def max_text_tokens=: (Integer) -> Integer
|
def max_text_tokens=: (Integer) -> Integer
|
||||||
|
|
||||||
|
# Max tokens to use from past text as prompt for the decoder.
|
||||||
|
#
|
||||||
def max_text_tokens: () -> Integer
|
def max_text_tokens: () -> Integer
|
||||||
|
|
||||||
def temperature=: (Float) -> Float
|
def temperature=: (Float) -> Float
|
||||||
def temperature: () -> Float
|
def temperature: () -> Float
|
||||||
def max_initial_ts=: (Float) -> Float
|
def max_initial_ts=: (Float) -> Float
|
||||||
|
|
||||||
|
# See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
|
||||||
|
#
|
||||||
def max_initial_ts: () -> Float
|
def max_initial_ts: () -> Float
|
||||||
|
|
||||||
def length_penalty=: (Float) -> Float
|
def length_penalty=: (Float) -> Float
|
||||||
def length_penalty: () -> Float
|
def length_penalty: () -> Float
|
||||||
def temperature_inc=: (Float) -> Float
|
def temperature_inc=: (Float) -> Float
|
||||||
def temperature_inc: () -> Float
|
def temperature_inc: () -> Float
|
||||||
def entropy_thold=: (Float) -> Float
|
def entropy_thold=: (Float) -> Float
|
||||||
|
|
||||||
|
# Similar to OpenAI's "compression_ratio_threshold"
|
||||||
|
#
|
||||||
def entropy_thold: () -> Float
|
def entropy_thold: () -> Float
|
||||||
|
|
||||||
def logprob_thold=: (Float) -> Float
|
def logprob_thold=: (Float) -> Float
|
||||||
def logprob_thold: () -> Float
|
def logprob_thold: () -> Float
|
||||||
def no_speech_thold=: (Float) -> Float
|
def no_speech_thold=: (Float) -> Float
|
||||||
def no_speech_thold: () -> Float
|
def no_speech_thold: () -> Float
|
||||||
|
|
||||||
|
# Sets new segment callback, called for every newly generated text segment.
|
||||||
|
#
|
||||||
|
# params.new_segment_callback = ->(context, _, n_new, user_data) {
|
||||||
|
# # ...
|
||||||
|
# }
|
||||||
|
#
|
||||||
def new_segment_callback=: (new_segment_callback) -> new_segment_callback
|
def new_segment_callback=: (new_segment_callback) -> new_segment_callback
|
||||||
def new_segment_callback: () -> (new_segment_callback | nil)
|
def new_segment_callback: () -> (new_segment_callback | nil)
|
||||||
|
|
||||||
|
# Sets user data passed to the last argument of new segment callback.
|
||||||
|
#
|
||||||
def new_segment_callback_user_data=: (Object) -> Object
|
def new_segment_callback_user_data=: (Object) -> Object
|
||||||
|
|
||||||
def new_segment_callback_user_data: () -> Object
|
def new_segment_callback_user_data: () -> Object
|
||||||
|
|
||||||
|
# Sets progress callback, called on each progress update.
|
||||||
|
#
|
||||||
|
# params.progress_callback = ->(context, _, progress, user_data) {
|
||||||
|
# # ...
|
||||||
|
# }
|
||||||
|
#
|
||||||
|
# +progress+ is an Integer between 0 and 100.
|
||||||
|
#
|
||||||
def progress_callback=: (progress_callback) -> progress_callback
|
def progress_callback=: (progress_callback) -> progress_callback
|
||||||
|
|
||||||
def progress_callback: () -> (progress_callback | nil)
|
def progress_callback: () -> (progress_callback | nil)
|
||||||
|
|
||||||
|
# Sets user data passed to the last argument of progress callback.
|
||||||
|
#
|
||||||
def progress_callback_user_data=: (Object) -> Object
|
def progress_callback_user_data=: (Object) -> Object
|
||||||
|
|
||||||
def progress_callback_user_data: () -> Object
|
def progress_callback_user_data: () -> Object
|
||||||
|
|
||||||
|
# Sets encoder begin callback, called when the encoder starts.
|
||||||
|
#
|
||||||
|
def encoder_begin_callback=: (encoder_begin_callback) -> encoder_begin_callback
|
||||||
|
|
||||||
|
def encoder_begin_callback: () -> (encoder_begin_callback | nil)
|
||||||
|
|
||||||
|
# Sets user data passed to the last argument of encoder begin callback.
|
||||||
|
#
|
||||||
|
def encoder_begin_callback_user_data=: (Object) -> Object
|
||||||
|
|
||||||
|
def encoder_begin_callback_user_data: () -> Object
|
||||||
|
|
||||||
|
# Sets abort callback, called to check if the process should be aborted.
|
||||||
|
#
|
||||||
|
# params.abort_callback = ->(user_data) {
|
||||||
|
# # ...
|
||||||
|
# }
|
||||||
|
#
|
||||||
|
#
|
||||||
def abort_callback=: (abort_callback) -> abort_callback
|
def abort_callback=: (abort_callback) -> abort_callback
|
||||||
|
|
||||||
def abort_callback: () -> (abort_callback | nil)
|
def abort_callback: () -> (abort_callback | nil)
|
||||||
|
|
||||||
|
# Sets user data passed to the last argument of abort callback.
|
||||||
|
#
|
||||||
def abort_callback_user_data=: (Object) -> Object
|
def abort_callback_user_data=: (Object) -> Object
|
||||||
|
|
||||||
def abort_callback_user_data: () -> Object
|
def abort_callback_user_data: () -> Object
|
||||||
|
|
||||||
|
# Hook called on new segment. Yields each Whisper::Segment.
|
||||||
|
#
|
||||||
|
# whisper.on_new_segment do |segment|
|
||||||
|
# # ...
|
||||||
|
# end
|
||||||
|
#
|
||||||
def on_new_segment: { (Segment) -> void } -> void
|
def on_new_segment: { (Segment) -> void } -> void
|
||||||
|
|
||||||
|
# Hook called on progress update. Yields each progress Integer between 0 and 100.
|
||||||
|
#
|
||||||
def on_progress: { (Integer progress) -> void } -> void
|
def on_progress: { (Integer progress) -> void } -> void
|
||||||
|
|
||||||
|
# Hook called when the encoder starts.
|
||||||
|
#
|
||||||
|
def on_encoder_begin: { () -> void } -> void
|
||||||
|
|
||||||
|
# Call block to determine whether abort or not. Return +true+ when you want to abort.
|
||||||
|
#
|
||||||
|
# params.abort_on do
|
||||||
|
# if some_condition
|
||||||
|
# true # abort
|
||||||
|
# else
|
||||||
|
# false # continue
|
||||||
|
# end
|
||||||
|
# end
|
||||||
|
#
|
||||||
def abort_on: { (Object user_data) -> boolish } -> void
|
def abort_on: { (Object user_data) -> boolish } -> void
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -167,16 +384,24 @@ module Whisper
|
|||||||
def type: () -> String
|
def type: () -> String
|
||||||
|
|
||||||
class URI
|
class URI
|
||||||
def self.new: (string | ::URI::HTTP) -> self
|
def self.new: (string | ::URI::HTTP) -> instance
|
||||||
def to_path: -> String
|
def to_path: -> String
|
||||||
def clear_cache: -> void
|
def clear_cache: -> void
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
class Segment
|
class Segment
|
||||||
|
# Start time in milliseconds.
|
||||||
|
#
|
||||||
def start_time: () -> Integer
|
def start_time: () -> Integer
|
||||||
|
|
||||||
|
# End time in milliseconds.
|
||||||
|
#
|
||||||
def end_time: () -> Integer
|
def end_time: () -> Integer
|
||||||
|
|
||||||
|
# Whether the next segment is predicted as a speaker turn.
|
||||||
def speaker_next_turn?: () -> (true | false)
|
def speaker_next_turn?: () -> (true | false)
|
||||||
|
|
||||||
def text: () -> String
|
def text: () -> String
|
||||||
def no_speech_prob: () -> Float
|
def no_speech_prob: () -> Float
|
||||||
end
|
end
|
||||||
|
@ -6,9 +6,9 @@ class TestBase < Test::Unit::TestCase
|
|||||||
AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")
|
AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")
|
||||||
|
|
||||||
class << self
|
class << self
|
||||||
attr_reader :whisper
|
def whisper
|
||||||
|
return @whisper if @whisper
|
||||||
|
|
||||||
def startup
|
|
||||||
@whisper = Whisper::Context.new("base.en")
|
@whisper = Whisper::Context.new("base.en")
|
||||||
params = Whisper::Params.new
|
params = Whisper::Params.new
|
||||||
params.print_timestamps = false
|
params.print_timestamps = false
|
||||||
@ -21,4 +21,15 @@ class TestBase < Test::Unit::TestCase
|
|||||||
def whisper
|
def whisper
|
||||||
self.class.whisper
|
self.class.whisper
|
||||||
end
|
end
|
||||||
|
|
||||||
|
module BuildOptions
|
||||||
|
load "ext/options.rb", self
|
||||||
|
Options.include self
|
||||||
|
|
||||||
|
def enable_config(name)
|
||||||
|
end
|
||||||
|
|
||||||
|
def arg_config(name)
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
@@ -111,6 +111,48 @@ class TestCallback < TestBase
     assert_equal 100, last
   end

+  def test_encoder_begin_callback
+    i = 0
+    @params.encoder_begin_callback = ->(context, state, user_data) {
+      i += 1
+    }
+    @whisper.transcribe(@audio, @params)
+    assert i > 0
+  end
+
+  def test_encoder_begin_callback_abort
+    logs = []
+    Whisper.log_set -> (level, buffer, user_data) {
+      logs << buffer if level == Whisper::LOG_LEVEL_ERROR
+    }, logs
+    @params.encoder_begin_callback = ->(context, state, user_data) {
+      return false
+    }
+    @whisper.transcribe(@audio, @params)
+    assert_match(/encoder_begin_callback returned false - aborting/, logs.join)
+    Whisper.log_set ->(level, buffer, user_data) {}, nil
+  end
+
+  def test_encoder_begin_callback_user_data
+    udata = Object.new
+    @params.encoder_begin_callback_user_data = udata
+    yielded = nil
+    @params.encoder_begin_callback = ->(context, state, user_data) {
+      yielded = user_data
+    }
+    @whisper.transcribe(@audio, @params)
+    assert_same udata, yielded
+  end
+
+  def test_on_encoder_begin
+    i = 0
+    @params.on_encoder_begin do
+      i += 1
+    end
+    @whisper.transcribe(@audio, @params)
+    assert i > 0
+  end
+
   def test_abort_callback
     i = 0
     @params.abort_callback = ->(user_data) {
@@ -21,11 +21,26 @@ class TestPackage < TestBase
     match_data = `rake -Tbuild`.match(/(whispercpp-(.+)\.gem)/)
     filename = match_data[1]
     version = match_data[2]
-    basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}"
     Dir.mktmpdir do |dir|
       system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{filename.shellescape}", exception: true
-      assert_path_exist File.join(dir, "gems/whispercpp-#{version}/lib", basename)
+      assert_installed dir, version
     end
   end
+
+  private
+
+  def assert_installed(dir, version)
+    assert_path_exist File.join(dir, "gems/whispercpp-#{version}/lib", "whisper.#{RbConfig::CONFIG["DLEXT"]}")
+    assert_path_exist File.join(dir, "gems/whispercpp-#{version}/LICENSE")
+    assert_path_not_exist File.join(dir, "gems/whispercpp-#{version}/ext/build")
+  end
+  end
+
+  def test_build_options
+    options = BuildOptions::Options.new
+    assert_empty options.missing_options
+    unless ENV["CI"]
+      assert_empty options.extra_options
+    end
 end
 end
@@ -3,8 +3,8 @@ require_relative "extsources"
 Gem::Specification.new do |s|
   s.name = "whispercpp"
   s.authors = ["Georgi Gerganov", "Todd A. Fisher"]
-  s.version = '1.3.1'
+  s.version = '1.3.2'
-  s.date = '2024-12-19'
+  s.date = '2025-05-01'
   s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby}
   s.email = 'todd.fisher@gmail.com'
   s.extra_rdoc_files = ['LICENSE', 'README.md']
@@ -15,7 +15,8 @@ Gem::Specification.new do |s|
   if s.extra_rdoc_files.include?(basename)
     basename
   else
-    file.sub("../..", "ext")
+    file.sub("../..", "ext/sources")
+        .sub("../javascript", "ext/sources/bindings/javascript")
   end
 }

@@ -26,7 +27,7 @@ Gem::Specification.new do |s|
   s.required_ruby_version = '>= 3.1.0'

   #### Documentation and testing.
-  s.homepage = 'https://github.com/ggerganov/whisper.cpp'
+  s.homepage = 'https://github.com/ggml-org/whisper.cpp'
   s.rdoc_options = ['--main', 'README.md']

@@ -41,6 +41,11 @@ COMMON_CMAKE_ARGS=(
     -DGGML_OPENMP=${GGML_OPENMP}
 )

+XCODE_VERSION=$(xcodebuild -version 2>/dev/null | head -n1 | awk '{ print $2 }')
+MAJOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f1)
+MINOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f2)
+echo "Detected Xcode version: $XCODE_VERSION"
+
 check_required_tool() {
     local tool=$1
     local install_message=$2
@@ -335,21 +340,28 @@ combine_static_libraries() {

     # Platform-specific post-processing for device builds
     if [[ "$is_simulator" == "false" ]]; then
-        if command -v vtool &>/dev/null; then
+        if command -v xcrun vtool &>/dev/null; then
             case "$platform" in
                 "ios")
                     echo "Marking binary as a framework binary for iOS..."
-                    vtool -set-build-version ios ${IOS_MIN_OS_VERSION} ${IOS_MIN_OS_VERSION} -replace \
+                    xcrun vtool -set-build-version ios ${IOS_MIN_OS_VERSION} ${IOS_MIN_OS_VERSION} -replace \
                         -output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
                     ;;
                 "visionos")
                     echo "Marking binary as a framework binary for visionOS..."
-                    vtool -set-build-version xros ${VISIONOS_MIN_OS_VERSION} ${VISIONOS_MIN_OS_VERSION} -replace \
+                    if [[ "$MAJOR_VERSION" -gt 16 ]] || [[ "$MAJOR_VERSION" -eq 16 && "$MINOR_VERSION" -gt 2 ]]; then
+                        echo "Xcode version greater than 16.2, using visionOS."
+                        VISION_OS_BUILD_VERSION="visionos"
+                    else
+                        echo "Xcode version less than or equal to 16.2, using xros."
+                        VISION_OS_BUILD_VERSION="xros"
+                    fi
+                    xcrun vtool -set-build-version ${VISION_OS_BUILD_VERSION} ${VISIONOS_MIN_OS_VERSION} ${VISIONOS_MIN_OS_VERSION} -replace \
                         -output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
                     ;;
                 "tvos")
                     echo "Marking binary as a framework binary for tvOS..."
-                    vtool -set-build-version tvos ${TVOS_MIN_OS_VERSION} ${TVOS_MIN_OS_VERSION} -replace \
+                    xcrun vtool -set-build-version tvos ${TVOS_MIN_OS_VERSION} ${TVOS_MIN_OS_VERSION} -replace \
                         -output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
                     ;;
             esac
@@ -19,6 +19,12 @@ const whisperParamsMock = {
     no_timestamps: false,
     audio_ctx: 0,
     max_len: 0,
+    prompt: "",
+    print_progress: false,
+    progress_callback: (progress) => {
+        console.log(`Progress: ${progress}`);
+    },
+    max_context: -1
 };

 describe("Run whisper.node", () => {
@@ -368,6 +368,12 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
     bool comma_in_time = whisper_params.Get("comma_in_time").As<Napi::Boolean>();
     int32_t max_len = whisper_params.Get("max_len").As<Napi::Number>();

+    // Add support for max_context
+    int32_t max_context = -1;
+    if (whisper_params.Has("max_context") && whisper_params.Get("max_context").IsNumber()) {
+        max_context = whisper_params.Get("max_context").As<Napi::Number>();
+    }
+
     // support prompt
     std::string prompt = "";
     if (whisper_params.Has("prompt") && whisper_params.Get("prompt").IsString()) {
@@ -407,6 +413,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
     params.pcmf32 = pcmf32_vec;
     params.comma_in_time = comma_in_time;
     params.max_len = max_len;
+    params.max_context = max_context;
     params.print_progress = print_progress;
     params.prompt = prompt;

@@ -4,7 +4,7 @@ A very basic tool for benchmarking the inference performance on your device. The
 the transformer on some random audio data and records the execution time. This way we can have an objective comparison
 of the performance of the model for various setups.

-Benchmark results are tracked in the following Github issue: https://github.com/ggerganov/whisper.cpp/issues/89
+Benchmark results are tracked in the following Github issue: https://github.com/ggml-org/whisper.cpp/issues/89

 ```bash
 # run the bench too on the small.en model using 4 threads
@@ -40,7 +40,7 @@ system_info: n_threads = 4 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WA

 If you wish, you can submit these results here:

-https://github.com/ggerganov/whisper.cpp/issues/89
+https://github.com/ggml-org/whisper.cpp/issues/89

 Please include the following information:

@@ -6,7 +6,8 @@ It can be used as a reference for using the `whisper.cpp` library in other proje
 ```
 ./build/bin/whisper-cli -h

-usage: ./build-pkg/bin/whisper-cli [options] file0.wav file1.wav ...
+usage: ./build/bin/whisper-cli [options] file0 file1 ...
+supported audio formats: flac, mp3, ogg, wav

 options:
   -h, --help [default] show this help message and exit
@@ -24,6 +25,7 @@ options:
   -wt N, --word-thold N [0.01 ] word timestamp probability threshold
   -et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
   -lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
+  -nth N, --no-speech-thold N [0.60 ] no speech threshold
   -tp, --temperature N [0.00 ] The sampling temperature, between 0 and 1
   -tpi, --temperature-inc N [0.20 ] The increment of temperature, between 0 and 1
   -debug, --debug-mode [false ] enable debug mode (eg. dump log_mel)
@@ -50,12 +52,13 @@ options:
   -dl, --detect-language [false ] exit after automatically detecting language
   --prompt PROMPT [ ] initial prompt (max n_text_ctx/2 tokens)
   -m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-  -f FNAME, --file FNAME [ ] input WAV file path
+  -f FNAME, --file FNAME [ ] input audio file path
   -oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
   -dtw MODEL --dtw MODEL [ ] compute token-level timestamps
   -ls, --log-score [false ] log best decoder scores of tokens
   -ng, --no-gpu [false ] disable GPU
   -fa, --flash-attn [false ] flash attention
+  -sns, --suppress-nst [false ] suppress non-speech tokens
   --suppress-regex REGEX [ ] regular expression matching tokens to suppress
   --grammar GRAMMAR [ ] GBNF grammar to guide decoding
   --grammar-rule RULE [ ] top-level GBNF grammar rule name
@@ -19,10 +19,6 @@
 #include <windows.h>
 #endif

-#if defined(_MSC_VER)
-#pragma warning(disable: 4244 4267) // possible loss of data
-#endif
-
 // helper function to replace substrings
 static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
     for (size_t pos = 0; ; pos += replace.length()) {
@@ -379,15 +375,7 @@ static void whisper_print_segment_callback(struct whisper_context * ctx, struct
     }
 }

-static bool output_txt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
-    std::ofstream fout(fname);
-    if (!fout.is_open()) {
-        fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
-        return false;
-    }
-
-    fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
-
+static void output_txt(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
     const int n_segments = whisper_full_n_segments(ctx);
     for (int i = 0; i < n_segments; ++i) {
         const char * text = whisper_full_get_segment_text(ctx, i);
@@ -402,19 +390,9 @@ static bool output_txt(struct whisper_context * ctx, const char * fname, const w

         fout << speaker << text << "\n";
     }
-
-    return true;
 }

-static bool output_vtt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
-    std::ofstream fout(fname);
-    if (!fout.is_open()) {
-        fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
-        return false;
-    }
-
-    fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
-
+static void output_vtt(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
     fout << "WEBVTT\n\n";

     const int n_segments = whisper_full_n_segments(ctx);
@@ -434,19 +412,9 @@ static bool output_vtt(struct whisper_context * ctx, const char * fname, const w
         fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
         fout << speaker << text << "\n\n";
     }
-
-    return true;
 }

-static bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
-    std::ofstream fout(fname);
-    if (!fout.is_open()) {
-        fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
-        return false;
-    }
-
-    fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
-
+static void output_srt(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
     const int n_segments = whisper_full_n_segments(ctx);
     for (int i = 0; i < n_segments; ++i) {
         const char * text = whisper_full_get_segment_text(ctx, i);
@@ -463,8 +431,6 @@ static bool output_srt(struct whisper_context * ctx, const char * fname, const w
         fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n";
         fout << speaker << text << "\n\n";
     }
-
-    return true;
 }

 static char * escape_double_quotes_and_backslashes(const char * str) {
@@ -530,15 +496,7 @@ static char * escape_double_quotes_in_csv(const char * str) {
     return escaped;
 }

-static bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
-    std::ofstream fout(fname);
-    if (!fout.is_open()) {
-        fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
-        return false;
-    }
-
-    fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
-
+static void output_csv(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
     const int n_segments = whisper_full_n_segments(ctx);
     fout << "start,end,";
     if (params.diarize && pcmf32s.size() == 2)
@@ -561,14 +519,9 @@ static bool output_csv(struct whisper_context * ctx, const char * fname, const w
         }
         fout << "\"" << text_escaped << "\"\n";
     }
-
-    return true;
 }

-static bool output_score(struct whisper_context * ctx, const char * fname, const whisper_params & /*params*/, std::vector<std::vector<float>> /*pcmf32s*/) {
-    std::ofstream fout(fname);
-    fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
-
+static void output_score(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & /*params*/, std::vector<std::vector<float>> /*pcmf32s*/) {
     const int n_segments = whisper_full_n_segments(ctx);
     // fprintf(stderr,"segments: %d\n",n_segments);
     for (int i = 0; i < n_segments; ++i) {
@@ -581,16 +534,14 @@ static bool output_score(struct whisper_context * ctx, const char * fname, const
             // fprintf(stderr,"token: %s %f\n",token,probability);
         }
     }
-    return true;
 }

-static bool output_json(
+static void output_json(
     struct whisper_context * ctx,
-    const char * fname,
+    std::ofstream & fout,
     const whisper_params & params,
-    std::vector<std::vector<float>> pcmf32s,
-    bool full) {
-    std::ofstream fout(fname);
+    std::vector<std::vector<float>> pcmf32s) {
+    const bool full = params.output_jsn_full;
     int indent = 0;

     auto doindent = [&]() {
@@ -670,12 +621,6 @@ static bool output_json(
         end_obj(end);
     };

-    if (!fout.is_open()) {
-        fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
-        return false;
-    }
-
-    fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
     start_obj(nullptr);
     value_s("systeminfo", whisper_print_system_info(), false);
     start_obj("model");
@@ -749,17 +694,12 @@ static bool output_json(

     end_arr(true);
     end_obj(true);
-    return true;
 }

 // karaoke video generation
 // outputs a bash script that uses ffmpeg to generate a video with the subtitles
 // TODO: font parameter adjustments
-static bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec, std::vector<std::vector<float>> pcmf32s) {
-    std::ofstream fout(fname);
-
-    fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
-
+static bool output_wts(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector<std::vector<float>> pcmf32s, const char * fname_inp, float t_sec, const char * fname_out) {
     static const char * font = params.font_path.c_str();

     std::ifstream fin(font);
@@ -875,20 +815,12 @@ static bool output_wts(struct whisper_context * ctx, const char * fname, const c

     fout.close();

-    fprintf(stderr, "%s: run 'source %s' to generate karaoke video\n", __func__, fname);
+    fprintf(stderr, "# %s: run 'source %s' to generate karaoke video\n", __func__, fname_out);

     return true;
 }

-static bool output_lrc(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
-    std::ofstream fout(fname);
-    if (!fout.is_open()) {
-        fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
-        return false;
-    }
-
-    fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
-
+static void output_lrc(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
     fout << "[by:whisper.cpp]\n";

     const int n_segments = whisper_full_n_segments(ctx);
@@ -916,8 +848,6 @@ static bool output_lrc(struct whisper_context * ctx, const char * fname, const w

         fout << '[' << timestamp_lrc << ']' << speaker << text << "\n";
     }
-
-    return true;
 }


@@ -1066,8 +996,55 @@ int main(int argc, char ** argv) {
     }

     for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
-        const auto fname_inp = params.fname_inp[f];
-        const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
+        const auto & fname_inp = params.fname_inp[f];
+        struct fout_factory {
+            std::string fname_out;
+            const size_t basename_length;
+            const bool is_stdout;
+            bool used_stdout;
+            decltype(whisper_print_segment_callback) * const print_segment_callback;
+            std::ofstream fout;
+
+            fout_factory (const std::string & fname_out_, const std::string & fname_inp, whisper_params & params) :
+                fname_out{!fname_out_.empty() ? fname_out_ : fname_inp},
+                basename_length{fname_out.size()},
+                is_stdout{fname_out == "-"},
+                used_stdout{},
+                print_segment_callback{is_stdout ? nullptr : whisper_print_segment_callback} {
+                if (!print_segment_callback) {
+                    params.print_progress = false;
+                }
+            }
+
+            bool open(const char * ext, const char * function) {
+                if (is_stdout) {
+                    if (used_stdout) {
+                        fprintf(stderr, "warning: Not appending multiple file formats to stdout\n");
+                        return false;
+                    }
+
+                    used_stdout = true;
+#ifdef _WIN32
+                    fout = std::ofstream{"CON"};
+#else
+                    fout = std::ofstream{"/dev/stdout"};
+#endif
+                    // Not using fprintf stderr here because it might equal stdout
+                    // Also assuming /dev is mounted
+                    return true;
+                }
+
+                fname_out.resize(basename_length);
+                fname_out += ext;
+                fout = std::ofstream{fname_out};
+                if (!fout.is_open()) {
+                    fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
+                    return false;
+                }
+                fprintf(stderr, "%s: saving output to '%s'\n", function, fname_out.c_str());
+                return true;
+            }
+        } fout_factory{f < (int) params.fname_out.size() ? params.fname_out[f] : "", fname_inp, params};

         std::vector<float> pcmf32;               // mono-channel F32 PCM
         std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
@@ -1172,7 +1149,7 @@ int main(int argc, char ** argv) {

         // this callback is called on each new segment
         if (!wparams.print_realtime) {
-            wparams.new_segment_callback = whisper_print_segment_callback;
+            wparams.new_segment_callback = fout_factory.print_segment_callback;
             wparams.new_segment_callback_user_data = &user_data;
         }

@@ -1214,54 +1191,26 @@ int main(int argc, char ** argv) {

         // output stuff
         {
-            printf("\n");
-
-            // output to text file
-            if (params.output_txt) {
-                const auto fname_txt = fname_out + ".txt";
-                output_txt(ctx, fname_txt.c_str(), params, pcmf32s);
-            }
-
-            // output to VTT file
-            if (params.output_vtt) {
-                const auto fname_vtt = fname_out + ".vtt";
-                output_vtt(ctx, fname_vtt.c_str(), params, pcmf32s);
-            }
-
-            // output to SRT file
-            if (params.output_srt) {
-                const auto fname_srt = fname_out + ".srt";
-                output_srt(ctx, fname_srt.c_str(), params, pcmf32s);
-            }
-
-            // output to WTS file
-            if (params.output_wts) {
-                const auto fname_wts = fname_out + ".wts";
-                output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE, pcmf32s);
-            }
-
-            // output to CSV file
-            if (params.output_csv) {
-                const auto fname_csv = fname_out + ".csv";
-                output_csv(ctx, fname_csv.c_str(), params, pcmf32s);
-            }
-
-            // output to JSON file
-            if (params.output_jsn) {
-                const auto fname_jsn = fname_out + ".json";
-                output_json(ctx, fname_jsn.c_str(), params, pcmf32s, params.output_jsn_full);
-            }
-
-            // output to LRC file
-            if (params.output_lrc) {
-                const auto fname_lrc = fname_out + ".lrc";
-                output_lrc(ctx, fname_lrc.c_str(), params, pcmf32s);
-            }
-
-            // output to score file
-            if (params.log_score) {
-                const auto fname_score = fname_out + ".score.txt";
-                output_score(ctx, fname_score.c_str(), params, pcmf32s);
+            // macros to stringify function name
+            #define output_func(func, ext, param, ...) if (param && fout_factory.open(ext, #func)) {\
+                func(ctx, fout_factory.fout, params, __VA_ARGS__); \
+            }
+            #define output_ext(ext, ...) output_func(output_##ext, "." #ext, params.output_##ext, __VA_ARGS__)
+
+            output_ext(txt, pcmf32s);
+            output_ext(vtt, pcmf32s);
+            output_ext(srt, pcmf32s);
+            output_ext(wts, pcmf32s, fname_inp.c_str(), float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE, fout_factory.fname_out.c_str());
+            output_ext(csv, pcmf32s);
+            output_func(output_json, ".json", params.output_jsn, pcmf32s);
+            output_ext(lrc, pcmf32s);
+            output_func(output_score, ".score.txt", params.log_score, pcmf32s);
+
+            #undef output_ext
+            #undef output_func
+
+            if (fout_factory.is_stdout && !fout_factory.used_stdout) {
+                fprintf(stderr, "warning: '--output-file -' used without any other '--output-*'");
             }
         }
     }
@@ -3,7 +3,7 @@
 // Speak short text commands to the microphone.
 // This program will detect your voice command and convert them to text.
 //
-// ref: https://github.com/ggerganov/whisper.cpp/issues/171
+// ref: https://github.com/ggml-org/whisper.cpp/issues/171
 //

 #include "common-sdl.h"
@@ -26,10 +26,6 @@
 #define MINIAUDIO_IMPLEMENTATION
 #include "miniaudio.h"

-#if defined(_MSC_VER)
-#pragma warning(disable: 4244 4267) // possible loss of data
-#endif
-
 #ifdef _WIN32
 #include <fcntl.h>
 #include <io.h>
@@ -10,10 +10,6 @@
 #include <regex>
 #include <sstream>

-#if defined(_MSC_VER)
-#pragma warning(disable: 4244 4267) // possible loss of data
-#endif
-
 // Function to check if the next argument exists
 static std::string get_next_arg(int& i, int argc, char** argv, const std::string& flag, gpt_params& params) {
     if (i + 1 < argc && argv[i + 1][0] != '-') {
@@ -194,7 +194,7 @@ static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size)
     AVIOContext *avio_ctx;
     AVStream *stream;
     AVCodecContext *codec;
-    AVPacket packet;
+    AVPacket *packet;
     AVFrame *frame;
     struct SwrContext *swr;
     u8 *avio_ctx_buffer;
@@ -249,6 +249,20 @@ static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size)
     /* prepare resampler */
     swr = swr_alloc();

+#if LIBAVCODEC_VERSION_MAJOR > 60
+    AVChannelLayout in_ch_layout = codec->ch_layout;
+    AVChannelLayout out_ch_layout = AV_CHANNEL_LAYOUT_MONO;
+
+    /* Set the source audio layout as-is */
+    av_opt_set_chlayout(swr, "in_chlayout", &in_ch_layout, 0);
+    av_opt_set_int(swr, "in_sample_rate", codec->sample_rate, 0);
+    av_opt_set_sample_fmt(swr, "in_sample_fmt", codec->sample_fmt, 0);
+
+    /* Convert it into 16khz Mono */
+    av_opt_set_chlayout(swr, "out_chlayout", &out_ch_layout, 0);
+    av_opt_set_int(swr, "out_sample_rate", WAVE_SAMPLE_RATE, 0);
+    av_opt_set_sample_fmt(swr, "out_sample_fmt", AV_SAMPLE_FMT_S16, 0);
+#else
     av_opt_set_int(swr, "in_channel_count", codec->channels, 0);
     av_opt_set_int(swr, "out_channel_count", 1, 0);
     av_opt_set_int(swr, "in_channel_layout", codec->channel_layout, 0);
@@ -257,6 +271,7 @@ static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size)
     av_opt_set_int(swr, "out_sample_rate", WAVE_SAMPLE_RATE, 0);
     av_opt_set_sample_fmt(swr, "in_sample_fmt", codec->sample_fmt, 0);
     av_opt_set_sample_fmt(swr, "out_sample_fmt", AV_SAMPLE_FMT_S16, 0);
+#endif

     swr_init(swr);
     if (!swr_is_initialized(swr)) {
@@ -264,7 +279,11 @@ static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size)
         return -1;
     }

-    av_init_packet(&packet);
+    packet=av_packet_alloc();
+    if (!packet) {
+        LOG("Error allocating the packet\n");
+        return -1;
+    }
     frame = av_frame_alloc();
     if (!frame) {
         LOG("Error allocating the frame\n");
@@ -274,8 +293,8 @@ static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size)
     /* iterate through frames */
     *data = NULL;
     *size = 0;
-    while (av_read_frame(fmt_ctx, &packet) >= 0) {
-        avcodec_send_packet(codec, &packet);
+    while (av_read_frame(fmt_ctx, packet) >= 0) {
+        avcodec_send_packet(codec, packet);

         err = avcodec_receive_frame(codec, frame);
         if (err == AVERROR(EAGAIN))
@@ -286,10 +305,11 @@ static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size)
     /* Flush any remaining conversion buffers... */
     convert_frame(swr, codec, frame, data, size, true);

+    av_packet_free(&packet);
     av_frame_free(&frame);
     swr_free(&swr);
     //avio_context_free(); // todo?
-    avcodec_close(codec);
+    avcodec_free_context(&codec);
     avformat_close_input(&fmt_ctx);
     avformat_free_context(fmt_ctx);

@@ -2,7 +2,7 @@
 #
 # Transcribe audio livestream by feeding ffmpeg output to whisper.cpp at regular intervals
 # Idea by @semiformal-net
-# ref: https://github.com/ggerganov/whisper.cpp/issues/185
+# ref: https://github.com/ggml-org/whisper.cpp/issues/185
 #

 set -eo pipefail
@@ -1,39 +1,115 @@
 import http.server
 import socketserver
 import os
+import sys
 from pathlib import Path
+import urllib.parse

 SCRIPT_DIR = Path(__file__).parent.absolute()
 DIRECTORY = os.path.join(SCRIPT_DIR, "../build-em/bin")
 DIRECTORY = os.path.abspath(DIRECTORY)

+# The context root we want for all applications
+CONTEXT_ROOT = "/whisper.cpp"
+
 class CustomHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, directory=DIRECTORY, **kwargs)

     def do_GET(self):
-        # If requesting a worker file from any subdirectory
-        if '.worker.js' in self.path:
+        # Redirect root to the context root
+        if self.path == '/':
+            self.send_response(302)
+            self.send_header('Location', CONTEXT_ROOT + '/')
+            self.end_headers()
+            return
+
+        # Handle requests under the context root
+        if self.path.startswith(CONTEXT_ROOT):
+            # Remove the context root prefix to get the actual path
+            actual_path = self.path[len(CONTEXT_ROOT):]
+
+            if not actual_path:
+                self.send_response(302)
+                self.send_header('Location', CONTEXT_ROOT + '/')
+                self.end_headers()
+                return
+
+            if '.worker.js' in actual_path:
+                worker_file = os.path.basename(actual_path)
+                worker_path = os.path.join(DIRECTORY, worker_file)
+
+                if os.path.exists(worker_path):
+                    print(f"Found worker file: {worker_path}")
+                    self.path = '/' + worker_file
+                else:
+                    print(f"Worker file not found: {worker_path}")
+
+            elif actual_path == '/':
+                self.path = '/whisper.wasm/index.html'
+            elif actual_path.startswith('/bench.wasm/') or actual_path.startswith('/command.wasm/') or actual_path.startswith('/stream.wasm/'):
+                # Keep the path as is, just remove the context root
+                self.path = actual_path
+            # For all other paths under the context root
+            else:
+                # Check if this is a request to a file in whisper.wasm
+                potential_file = os.path.join(DIRECTORY, 'whisper.wasm', actual_path.lstrip('/'))
+                if os.path.exists(potential_file) and not os.path.isdir(potential_file):
+                    self.path = '/whisper.wasm' + actual_path
+                else:
+                    # Try to resolve the file from the base directory
+                    potential_file = os.path.join(DIRECTORY, actual_path.lstrip('/'))
+                    if os.path.exists(potential_file):
+                        self.path = actual_path
+
+        # For direct requests to worker files (without context root as these
+        # are in the build-em/bin directory
+        elif '.worker.js' in self.path:
             worker_file = os.path.basename(self.path)
             worker_path = os.path.join(DIRECTORY, worker_file)

             if os.path.exists(worker_path):
                 self.path = '/' + worker_file

+        # Handle coi-serviceworker.js separately
+        if 'coi-serviceworker.js' in self.path:
+            worker_file = "coi-serviceworker.js"
+            worker_path = os.path.join(SCRIPT_DIR, worker_file)
+            if os.path.exists(worker_path):
+                self.send_response(200)
+                self.send_header('Content-type', 'application/javascript')
+                self.end_headers()
+                with open(worker_path, 'rb') as file:
+                    self.wfile.write(file.read())
+                return
+            else:
+                print(f"Warning: Could not find {worker_path}")
+
         return super().do_GET()

     def end_headers(self):
         # Add required headers for SharedArrayBuffer
         self.send_header("Cross-Origin-Opener-Policy", "same-origin")
         self.send_header("Cross-Origin-Embedder-Policy", "require-corp")
-        self.send_header("Access-Control-Allow-Origin", "*");
+        self.send_header("Access-Control-Allow-Origin", "*")
         super().end_headers()

 PORT = 8000

-with socketserver.TCPServer(("", PORT), CustomHTTPRequestHandler) as httpd:
-    print(f"Serving directory '{DIRECTORY}' at http://localhost:{PORT}")
-    try:
-        httpd.serve_forever()
-    except KeyboardInterrupt:
-        print("\nServer stopped.")
+# Enable address reuse
+class CustomServer(socketserver.TCPServer):
+    allow_reuse_address = True
+
+try:
+    with CustomServer(("", PORT), CustomHTTPRequestHandler) as httpd:
+        print(f"Serving directory '{DIRECTORY}' at http://localhost:{PORT}")
+        print(f"Application context root: http://localhost:{PORT}{CONTEXT_ROOT}/")
+        try:
+            httpd.serve_forever()
+        except KeyboardInterrupt:
+            print("\nServer stopped.")
+            # Force complete exit
+            sys.exit(0)
+except OSError as e:
+    print(f"Error: {e}")
+    sys.exit(1)
File diff suppressed because it is too large
@@ -14,10 +14,6 @@
 #include <thread>
 #include <vector>

-#if defined(_MSC_VER)
-#pragma warning(disable: 4244 4267) // possible loss of data
-#endif
-
 using namespace httplib;
 using json = nlohmann::ordered_json;

@@ -79,6 +75,7 @@ struct whisper_params {
     bool use_gpu = true;
     bool flash_attn = false;
     bool suppress_nst = false;
+    bool no_context = false;

     std::string language = "en";
     std::string prompt = "";
@@ -140,6 +137,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server\n", sparams.ffmpeg_converter ? "true" : "false");
     fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false");
     fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold);
+    fprintf(stderr, " -nc, --no-context [%-7s] do not use previous audio context\n", params.no_context ? "true" : "false");
+    fprintf(stderr, " -ng, --no-gpu [%-7s] do not use gpu\n", params.use_gpu ? "false" : "true");
     fprintf(stderr, "\n");
 }

@@ -186,6 +185,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
     else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
     else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; }
     else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(argv[++i]); }
+    else if (arg == "-nc" || arg == "--no-context") { params.no_context = true; }

     // server params
     else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
@@ -506,6 +506,10 @@ void get_req_parameters(const Request & req, whisper_params & params)
     {
         params.suppress_nst = parse_str_to_bool(req.get_file_value("suppress_nst").content);
     }
+    if (req.has_file("no_context"))
+    {
+        params.no_context = parse_str_to_bool(req.get_file_value("no_context").content);
+    }
 }

 } // namespace
@@ -818,6 +822,7 @@ int main(int argc, char ** argv) {

         wparams.no_timestamps = params.no_timestamps;
         wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format;
+        wparams.no_context = params.no_context;

         wparams.suppress_nst = params.suppress_nst;

@@ -834,33 +839,25 @@ int main(int argc, char ** argv) {
             wparams.progress_callback_user_data = &user_data;
         }

-        // examples for abort mechanism
-        // in examples below, we do not abort the processing, but we could if the flag is set to true
-
-        // the callback is called before every encoder run - if it returns false, the processing is aborted
-        {
-            static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
-
-            wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
-                bool is_aborted = *(bool*)user_data;
-                return !is_aborted;
-            };
-            wparams.encoder_begin_callback_user_data = &is_aborted;
-        }
-
-        // the callback is called before every computation - if it returns true, the computation is aborted
-        {
-            static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
-
-            wparams.abort_callback = [](void * user_data) {
-                bool is_aborted = *(bool*)user_data;
-                return is_aborted;
-            };
-            wparams.abort_callback_user_data = &is_aborted;
-        }
+        // tell whisper to abort if the HTTP connection closed
+        wparams.abort_callback = [](void *user_data) {
+            // user_data is a pointer to our Request
+            auto req_ptr = static_cast<const httplib::Request*>(user_data);
+            return req_ptr->is_connection_closed();
+        };
+        wparams.abort_callback_user_data = (void*)&req;

         if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
+            // handle failure or early abort
+            if (req.is_connection_closed()) {
+                // log client disconnect
+                fprintf(stderr, "client disconnected, aborted processing\n");
+                res.status = 499; // Client Closed Request (nginx convention)
+                res.set_content("{\"error\":\"client disconnected\"}", "application/json");
+                return;
+            }
             fprintf(stderr, "%s: failed to process audio\n", argv[0]);
+            res.status = 500; // Internal Server Error
             const std::string error_resp = "{\"error\":\"failed to process audio\"}";
             res.set_content(error_resp, "application/json");
             return;
@@ -918,14 +915,26 @@ int main(int argc, char ** argv) {
             res.set_content(ss.str(), "text/vtt");
         } else if (params.response_format == vjson_format) {
             /* try to match openai/whisper's Python format */
             std::string results = output_str(ctx, params, pcmf32s);
+            // Get language probabilities
+            std::vector<float> lang_probs(whisper_lang_max_id() + 1, 0.0f);
+            const auto detected_lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, lang_probs.data());
             json jres = json{
                 {"task", params.translate ? "translate" : "transcribe"},
                 {"language", whisper_lang_str_full(whisper_full_lang_id(ctx))},
                 {"duration", float(pcmf32.size())/WHISPER_SAMPLE_RATE},
                 {"text", results},
-                {"segments", json::array()}
+                {"segments", json::array()},
+                {"detected_language", whisper_lang_str_full(detected_lang_id)},
+                {"detected_language_probability", lang_probs[detected_lang_id]},
+                {"language_probabilities", json::object()}
             };
+            // Add all language probabilities
+            for (int i = 0; i <= whisper_lang_max_id(); ++i) {
+                if (lang_probs[i] > 0.001f) { // Only include non-negligible probabilities
+                    jres["language_probabilities"][whisper_lang_str(i)] = lang_probs[i];
+                }
+            }
             const int n_segments = whisper_full_n_segments(ctx);
             for (int i = 0; i < n_segments; ++i)
             {
@@ -12,9 +12,12 @@ if (WHISPER_SDL2)
         llama-context.cpp
         llama-cparams.cpp
         llama-grammar.cpp
+        llama-graph.cpp
        llama-hparams.cpp
        llama-impl.cpp
+        llama-io.cpp
        llama-kv-cache.cpp
+        llama-memory.cpp
        llama-mmap.cpp
        llama-model-loader.cpp
        llama-model.cpp
@@ -4,14 +4,13 @@
 #include "llama-mmap.h"
 #include "llama-model.h"

-#include <algorithm>
 #include <map>
 #include <cassert>
 #include <stdexcept>

 // vec

-struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
+ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
     if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) {
         return nullptr;
     }
@@ -19,7 +18,7 @@ struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
     return tensors[il];
 }

-struct ggml_tensor * llama_adapter_cvec::apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const {
+ggml_tensor * llama_adapter_cvec::apply_to(ggml_context * ctx, ggml_tensor * cur, int il) const {
     ggml_tensor * layer_dir = tensor_for(il);
     if (layer_dir != nullptr) {
         cur = ggml_add(ctx, cur, layer_dir);
@@ -40,7 +39,7 @@ bool llama_adapter_cvec::init(const llama_model & model) {
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
-            struct ggml_init_params params = {
+            ggml_init_params params = {
                 /*.mem_size =*/ hparams.n_layer*ggml_tensor_overhead(),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc =*/ true,
@@ -91,7 +90,7 @@ bool llama_adapter_cvec::init(const llama_model & model) {
     return true;
 }

-int32_t llama_adapter_cvec::apply(
+bool llama_adapter_cvec::apply(
         const llama_model & model,
         const float * data,
         size_t len,
@@ -104,17 +103,17 @@ int32_t llama_adapter_cvec::apply(
         // disable the current control vector (but leave allocated for later)
         layer_start = -1;
         layer_end   = -1;
-        return 0;
+        return true;
     }

     if (n_embd != (int) hparams.n_embd) {
         LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__);
-        return 1;
+        return false;
     }

     if (tensors.empty()) {
         if (!init(model)) {
-            return 1;
+            return false;
         }
     }

@@ -130,12 +129,12 @@ int32_t llama_adapter_cvec::apply(
         }
     }

-    return 0;
+    return true;
 }

 // lora

-llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor * w) {
+llama_adapter_lora_weight * llama_adapter_lora::get_weight(ggml_tensor * w) {
     const std::string name(w->name);

     const auto pos = ab_map.find(name);
@@ -146,11 +145,11 @@ llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor *
     return nullptr;
 }

-static void llama_adapter_lora_init_impl(struct llama_model & model, const char * path_lora, struct llama_adapter_lora & adapter) {
+static void llama_adapter_lora_init_impl(llama_model & model, const char * path_lora, llama_adapter_lora & adapter) {
     LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);

     ggml_context * ctx_init;
-    struct gguf_init_params meta_gguf_params = {
+    gguf_init_params meta_gguf_params = {
         /* .no_alloc = */ true,
         /* .ctx = */ &ctx_init,
     };
@@ -201,7 +200,7 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
             // add a new context
-            struct ggml_init_params params = {
+            ggml_init_params params = {
                 /*.mem_size =*/ n_tensors*ggml_tensor_overhead(),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc =*/ true,
@@ -248,6 +247,26 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
         }
     }

+    // get extra buffer types of the CPU
+    // TODO: a more general solution for non-CPU extra buft should be imlpemented in the future
+    // ref: https://github.com/ggml-org/llama.cpp/pull/12593#pullrequestreview-2718659948
+    std::vector<ggml_backend_buffer_type_t> buft_extra;
+    {
+        auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+        auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
+
+        auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
+            ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
+
+        if (ggml_backend_dev_get_extra_bufts_fn) {
+            ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev);
+            while (extra_bufts && *extra_bufts) {
+                buft_extra.emplace_back(*extra_bufts);
+                ++extra_bufts;
+            }
+        }
+    }
+
     // add tensors
     for (auto & it : ab_map) {
         const std::string & name = it.first;
@@ -264,7 +283,23 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
             throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)");
         }

-        struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
+        auto * buft = ggml_backend_buffer_get_type(model_tensor->buffer);
+
+        // do not load loras to extra buffer types (i.e. bufts for repacking) -> use the CPU in that case
+        for (auto & ex : buft_extra) {
+            if (ex == buft) {
+                LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
+
+                auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+                buft = ggml_backend_dev_buffer_type(cpu_dev);
+
+                break;
+            }
+        }
+
+        LLAMA_LOG_DEBUG("%s: lora for '%s' -> '%s'\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
+
+        ggml_context * dev_ctx = ctx_for_buft(buft);
         // validate tensor shape
         if (is_token_embd) {
             // expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd()
|
// expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd()
|
||||||
@ -281,8 +316,8 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
|
|||||||
}
|
}
|
||||||
|
|
||||||
// save tensor to adapter
|
// save tensor to adapter
|
||||||
struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
|
ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
|
||||||
struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
|
ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
|
||||||
ggml_set_name(tensor_a, w.a->name);
|
ggml_set_name(tensor_a, w.a->name);
|
||||||
ggml_set_name(tensor_b, w.b->name);
|
ggml_set_name(tensor_b, w.b->name);
|
||||||
adapter.ab_map[name] = llama_adapter_lora_weight(tensor_a, tensor_b);
|
adapter.ab_map[name] = llama_adapter_lora_weight(tensor_a, tensor_b);
|
||||||
@ -308,7 +343,7 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
|
|||||||
{
|
{
|
||||||
llama_file gguf_file(path_lora, "rb");
|
llama_file gguf_file(path_lora, "rb");
|
||||||
std::vector<uint8_t> read_buf;
|
std::vector<uint8_t> read_buf;
|
||||||
auto set_tensor = [&](struct ggml_tensor * orig, struct ggml_tensor * dev) {
|
auto set_tensor = [&](ggml_tensor * orig, ggml_tensor * dev) {
|
||||||
size_t offs = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), gguf_find_tensor(ctx_gguf.get(), orig->name));
|
size_t offs = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), gguf_find_tensor(ctx_gguf.get(), orig->name));
|
||||||
size_t size = ggml_nbytes(orig);
|
size_t size = ggml_nbytes(orig);
|
||||||
read_buf.resize(size);
|
read_buf.resize(size);
|
||||||
@ -327,8 +362,8 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
|
|||||||
LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
|
LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model, const char * path_lora) {
|
llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * path_lora) {
|
||||||
struct llama_adapter_lora * adapter = new llama_adapter_lora();
|
llama_adapter_lora * adapter = new llama_adapter_lora();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
llama_adapter_lora_init_impl(*model, path_lora, *adapter);
|
llama_adapter_lora_init_impl(*model, path_lora, *adapter);
|
||||||
@ -342,6 +377,6 @@ struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model,
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_adapter_lora_free(struct llama_adapter_lora * adapter) {
|
void llama_adapter_lora_free(llama_adapter_lora * adapter) {
|
||||||
delete adapter;
|
delete adapter;
|
||||||
}
|
}
|
||||||
|
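The hunk above makes the LoRA loader avoid "extra" (repacked) CPU buffer types and fall back to the plain CPU buffer type. The following is a minimal standalone sketch of the same lookup pattern, not the project's code; it only assumes the ggml-backend headers and the "ggml_backend_dev_get_extra_bufts" proc address used in the diff.

// Sketch: collect the CPU device's extra buffer types, then decide whether a
// tensor's buffer type must fall back to the default CPU buffer type.
#include "ggml-backend.h"

#include <vector>

static std::vector<ggml_backend_buffer_type_t> collect_cpu_extra_bufts() {
    std::vector<ggml_backend_buffer_type_t> bufts;

    ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
    ggml_backend_reg_t cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);

    // the typedef and proc name are the ones used in the diff above
    auto get_extra_bufts = (ggml_backend_dev_get_extra_bufts_t)
        ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");

    if (get_extra_bufts) {
        for (ggml_backend_buffer_type_t * p = get_extra_bufts(cpu_dev); p && *p; ++p) {
            bufts.push_back(*p);
        }
    }
    return bufts;
}

// Returns the buffer type to use for an adapter tensor: the model tensor's own
// buffer type, unless that type is an "extra" (repacking) type, in which case
// the default CPU buffer type is returned instead.
static ggml_backend_buffer_type_t pick_adapter_buft(
        ggml_backend_buffer_type_t model_buft,
        const std::vector<ggml_backend_buffer_type_t> & extra) {
    for (ggml_backend_buffer_type_t ex : extra) {
        if (ex == model_buft) {
            ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
            return ggml_backend_dev_buffer_type(cpu_dev);
        }
    }
    return model_buft;
}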
@@ -15,11 +15,11 @@
//

struct llama_adapter_cvec {
-    struct ggml_tensor * tensor_for(int il) const;
+    ggml_tensor * tensor_for(int il) const;

-    struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const;
+    ggml_tensor * apply_to(ggml_context * ctx, ggml_tensor * cur, int il) const;

-    int32_t apply(
+    bool apply(
            const llama_model & model,
            const float * data,
            size_t len,
@@ -36,7 +36,7 @@ private:
    std::vector<ggml_context_ptr> ctxs;
    std::vector<ggml_backend_buffer_ptr> bufs;

-    std::vector<struct ggml_tensor *> tensors; // per layer
+    std::vector<ggml_tensor *> tensors; // per layer
};

//
@@ -44,8 +44,8 @@ private:
//

struct llama_adapter_lora_weight {
-    struct ggml_tensor * a = nullptr;
-    struct ggml_tensor * b = nullptr;
+    ggml_tensor * a = nullptr;
+    ggml_tensor * b = nullptr;

    // get actual scale based on rank and alpha
    float get_scale(float alpha, float adapter_scale) const {
@@ -55,12 +55,12 @@ struct llama_adapter_lora_weight {
    }

    llama_adapter_lora_weight() = default;
-    llama_adapter_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {}
+    llama_adapter_lora_weight(ggml_tensor * a, ggml_tensor * b) : a(a), b(b) {}
};

struct llama_adapter_lora {
    // map tensor name to lora_a_b
-    std::unordered_map<std::string, struct llama_adapter_lora_weight> ab_map;
+    std::unordered_map<std::string, llama_adapter_lora_weight> ab_map;

    std::vector<ggml_context_ptr> ctxs;
    std::vector<ggml_backend_buffer_ptr> bufs;
@@ -70,5 +70,7 @@ struct llama_adapter_lora {
    llama_adapter_lora() = default;
    ~llama_adapter_lora() = default;

-    llama_adapter_lora_weight * get_weight(struct ggml_tensor * w);
+    llama_adapter_lora_weight * get_weight(ggml_tensor * w);
};
+
+using llama_adapter_loras = std::unordered_map<llama_adapter_lora *, float>;
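The last added line introduces the `llama_adapter_loras` alias, mapping each loaded adapter to its user-requested scale. A small illustrative sketch of working with that alias follows; the helper function here is hypothetical and only shows the shape of the data, not how the scales are applied internally.

// Sketch: llama_adapter_loras maps a loaded adapter to its scale factor.
#include <unordered_map>

struct llama_adapter_lora; // opaque here; defined in llama-adapter.h

using llama_adapter_loras = std::unordered_map<llama_adapter_lora *, float>;

// Hypothetical helper: sum the requested scales, e.g. to warn when the total is large.
static float total_adapter_scale(const llama_adapter_loras & loras) {
    float sum = 0.0f;
    for (const auto & entry : loras) {
        sum += entry.second;
    }
    return sum;
}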
@@ -6,6 +6,7 @@

static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_LLAMA, "llama" },
+    { LLM_ARCH_LLAMA4, "llama4" },
    { LLM_ARCH_DECI, "deci" },
    { LLM_ARCH_FALCON, "falcon" },
    { LLM_ARCH_GROK, "grok" },
@@ -18,6 +19,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_REFACT, "refact" },
    { LLM_ARCH_BERT, "bert" },
    { LLM_ARCH_NOMIC_BERT, "nomic-bert" },
+    { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
    { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
    { LLM_ARCH_BLOOM, "bloom" },
    { LLM_ARCH_STABLELM, "stablelm" },
@@ -25,6 +27,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_QWEN2, "qwen2" },
    { LLM_ARCH_QWEN2MOE, "qwen2moe" },
    { LLM_ARCH_QWEN2VL, "qwen2vl" },
+    { LLM_ARCH_QWEN3, "qwen3" },
+    { LLM_ARCH_QWEN3MOE, "qwen3moe" },
    { LLM_ARCH_PHI2, "phi2" },
    { LLM_ARCH_PHI3, "phi3" },
    { LLM_ARCH_PHIMOE, "phimoe" },
@@ -36,6 +40,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_MINICPM3, "minicpm3" },
    { LLM_ARCH_GEMMA, "gemma" },
    { LLM_ARCH_GEMMA2, "gemma2" },
+    { LLM_ARCH_GEMMA3, "gemma3" },
    { LLM_ARCH_STARCODER2, "starcoder2" },
    { LLM_ARCH_MAMBA, "mamba" },
    { LLM_ARCH_XVERSE, "xverse" },
@@ -50,6 +55,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_DEEPSEEK, "deepseek" },
    { LLM_ARCH_DEEPSEEK2, "deepseek2" },
    { LLM_ARCH_CHATGLM, "chatglm" },
+    { LLM_ARCH_GLM4, "glm4" },
    { LLM_ARCH_BITNET, "bitnet" },
    { LLM_ARCH_T5, "t5" },
    { LLM_ARCH_T5ENCODER, "t5encoder" },
@@ -58,10 +64,14 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_EXAONE, "exaone" },
    { LLM_ARCH_RWKV6, "rwkv6" },
    { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
+    { LLM_ARCH_RWKV7, "rwkv7" },
+    { LLM_ARCH_ARWKV7, "arwkv7" },
    { LLM_ARCH_GRANITE, "granite" },
    { LLM_ARCH_GRANITE_MOE, "granitemoe" },
    { LLM_ARCH_CHAMELEON, "chameleon" },
    { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
+    { LLM_ARCH_PLM, "plm" },
+    { LLM_ARCH_BAILINGMOE, "bailingmoe" },
    { LLM_ARCH_UNKNOWN, "(unknown)" },
};

@@ -70,6 +80,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
    { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" },
    { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" },
    { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" },
+    { LLM_KV_GENERAL_FILE_TYPE, "general.file_type" },
    { LLM_KV_GENERAL_NAME, "general.name" },
    { LLM_KV_GENERAL_AUTHOR, "general.author" },
    { LLM_KV_GENERAL_VERSION, "general.version" },
@@ -96,6 +107,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
    { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
    { LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
    { LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
+    { LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" },
    { LLM_KV_POOLING_TYPE, "%s.pooling_type" },
    { LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
    { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
@@ -108,23 +120,30 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
    { LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" },
    { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
    { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" },
+    { LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" },

    { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
    { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
    { LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" },
    { LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" },
    { LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" },
    { LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" },
    { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
    { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
    { LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" },
    { LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" },
    { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
    { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
    { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
-    { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
-    { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
-    { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
+    { LLM_KV_ATTENTION_DECAY_LORA_RANK, "%s.attention.decay_lora_rank" },
+    { LLM_KV_ATTENTION_ICLR_LORA_RANK, "%s.attention.iclr_lora_rank" },
+    { LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, "%s.attention.value_residual_mix_lora_rank" },
+    { LLM_KV_ATTENTION_GATE_LORA_RANK, "%s.attention.gate_lora_rank" },
+    { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
+    { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
+    { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
+    { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
+    { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },

    { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
    { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
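The LLM_KV_NAMES entries above are printf-style patterns in which "%s" stands for the architecture name from LLM_ARCH_NAMES, giving the concrete GGUF metadata key. A short illustrative sketch (the helper name is hypothetical, the pattern strings come from the table above):

// Sketch: expanding an architecture-scoped key pattern into a concrete GGUF key.
#include <cstdio>
#include <string>

static std::string format_kv_key(const char * pattern, const char * arch_name) {
    char buf[256];
    std::snprintf(buf, sizeof(buf), pattern, arch_name);
    return buf;
}

// format_kv_key("%s.attention.key_length_mla", "deepseek2")    -> "deepseek2.attention.key_length_mla"
// format_kv_key("%s.moe_every_n_layers", "nomic-bert-moe")     -> "nomic-bert-moe.moe_every_n_layers"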
@@ -223,6 +242,35 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
        },
    },
+    {
+        LLM_ARCH_LLAMA4,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
+            { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
+            { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+        },
+    },
    {
        LLM_ARCH_DECI,
        {
@@ -426,6 +474,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
        },
    },
+    {
+        LLM_ARCH_NOMIC_BERT_MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+            { LLM_TENSOR_TOKEN_TYPES, "token_types" },
+            { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
+            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+        },
+    },
    {
        LLM_ARCH_JINA_BERT_V2,
        {
@@ -554,6 +620,45 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
        },
    },
+    {
+        LLM_ARCH_QWEN3,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
+    {
+        LLM_ARCH_QWEN3MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+        },
+    },
    {
        LLM_ARCH_PHI2,
        {
@@ -766,6 +871,27 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
            { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
        },
    },
+    {
+        LLM_ARCH_GEMMA3,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+        },
+    },
    {
        LLM_ARCH_STARCODER2,
        {
@@ -999,6 +1125,8 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
            { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
            { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
            { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
+            { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
+            { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
@@ -1015,6 +1143,22 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
            { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
        },
    },
+    {
+        LLM_ARCH_PLM,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
+            { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" },
+            { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
    {
        LLM_ARCH_CHATGLM,
        {
@@ -1033,6 +1177,25 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
        },
    },
+    {
+        LLM_ARCH_GLM4,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+        },
+    },
    {
        LLM_ARCH_BITNET,
        {
@@ -1217,6 +1380,74 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
        },
    },
+    {
+        LLM_ARCH_RWKV7,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
+            { LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" },
+            { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
+            { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
+            { LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" },
+            { LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" },
+            { LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" },
+            { LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" },
+            { LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" },
+            { LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" },
+            { LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" },
+            { LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" },
+            { LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" },
+            { LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" },
+            { LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" },
+            { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
+            { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
+            { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
+            { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
+            { LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
+            { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
+            { LLM_TENSOR_CHANNEL_MIX_LERP_K, "blk.%d.channel_mix_lerp_k" },
+            { LLM_TENSOR_CHANNEL_MIX_KEY, "blk.%d.channel_mix_key" },
+            { LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" },
+        },
+    },
+    {
+        LLM_ARCH_ARWKV7,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" },
+            { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
+            { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
+            { LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" },
+            { LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" },
+            { LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" },
+            { LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" },
+            { LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" },
+            { LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" },
+            { LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" },
+            { LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" },
+            { LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" },
+            { LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" },
+            { LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" },
+            { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
+            { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
+            { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
+            { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
+            { LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
+            { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
    {
        LLM_ARCH_GRANITE,
        {
@@ -1296,6 +1527,29 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
            { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" },
        },
    },
+    {
+        LLM_ARCH_BAILINGMOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
+            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+        },
+    },
    {
        LLM_ARCH_UNKNOWN,
        {
@@ -1333,23 +1587,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
    {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
@@ -1376,6 +1615,12 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
    {LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_A1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_A2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_V1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_V2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_G1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_TIME_MIX_G2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_TIME_MIX_DECAY_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_TIME_MIX_DECAY_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_TIME_MIX_KEY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
@@ -1394,6 +1639,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
    {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
    {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
    {LLM_TENSOR_CHANNEL_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_TIME_MIX_K_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_TIME_MIX_K_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_TIME_MIX_R_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
    {LLM_TENSOR_TIME_MIX_LERP_W, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
    {LLM_TENSOR_TIME_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
    {LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
@@ -1401,6 +1649,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
    {LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
    {LLM_TENSOR_TIME_MIX_LERP_FUSED, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
    {LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
+    {LLM_TENSOR_TIME_MIX_W0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
+    {LLM_TENSOR_TIME_MIX_A0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
+    {LLM_TENSOR_TIME_MIX_V0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
    {LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
    {LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
    {LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
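The per-layer entries in the tensor-name tables above are "blk.%d." patterns that get expanded with the layer index when tensors are created or looked up. A small illustrative sketch (the helper name is hypothetical, the pattern strings are taken from the tables):

// Sketch: forming a concrete per-layer tensor name from a "blk.%d." pattern.
#include <cstdio>
#include <string>

static std::string format_tensor_name(const char * pattern, int il) {
    char buf[128];
    std::snprintf(buf, sizeof(buf), pattern, il);
    return buf;
}

// format_tensor_name("blk.%d.time_mix_w0", 3) -> "blk.3.time_mix_w0"
// format_tensor_name("blk.%d.attn_k_b", 0)    -> "blk.0.attn_k_b"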
@@ -10,6 +10,7 @@

enum llm_arch {
    LLM_ARCH_LLAMA,
+    LLM_ARCH_LLAMA4,
    LLM_ARCH_DECI,
    LLM_ARCH_FALCON,
    LLM_ARCH_BAICHUAN,
@@ -22,6 +23,7 @@ enum llm_arch {
    LLM_ARCH_REFACT,
    LLM_ARCH_BERT,
    LLM_ARCH_NOMIC_BERT,
+    LLM_ARCH_NOMIC_BERT_MOE,
    LLM_ARCH_JINA_BERT_V2,
    LLM_ARCH_BLOOM,
    LLM_ARCH_STABLELM,
@@ -29,6 +31,8 @@ enum llm_arch {
    LLM_ARCH_QWEN2,
    LLM_ARCH_QWEN2MOE,
    LLM_ARCH_QWEN2VL,
+    LLM_ARCH_QWEN3,
+    LLM_ARCH_QWEN3MOE,
    LLM_ARCH_PHI2,
    LLM_ARCH_PHI3,
    LLM_ARCH_PHIMOE,
@@ -40,6 +44,7 @@ enum llm_arch {
    LLM_ARCH_MINICPM3,
    LLM_ARCH_GEMMA,
    LLM_ARCH_GEMMA2,
+    LLM_ARCH_GEMMA3,
    LLM_ARCH_STARCODER2,
    LLM_ARCH_MAMBA,
    LLM_ARCH_XVERSE,
@@ -54,6 +59,7 @@ enum llm_arch {
    LLM_ARCH_DEEPSEEK,
    LLM_ARCH_DEEPSEEK2,
    LLM_ARCH_CHATGLM,
+    LLM_ARCH_GLM4,
    LLM_ARCH_BITNET,
    LLM_ARCH_T5,
    LLM_ARCH_T5ENCODER,
@@ -62,10 +68,14 @@ enum llm_arch {
    LLM_ARCH_EXAONE,
    LLM_ARCH_RWKV6,
    LLM_ARCH_RWKV6QWEN2,
+    LLM_ARCH_RWKV7,
+    LLM_ARCH_ARWKV7,
    LLM_ARCH_GRANITE,
    LLM_ARCH_GRANITE_MOE,
    LLM_ARCH_CHAMELEON,
    LLM_ARCH_WAVTOKENIZER_DEC,
+    LLM_ARCH_PLM,
+    LLM_ARCH_BAILINGMOE,
    LLM_ARCH_UNKNOWN,
};

@@ -74,6 +84,7 @@ enum llm_kv {
    LLM_KV_GENERAL_ARCHITECTURE,
    LLM_KV_GENERAL_QUANTIZATION_VERSION,
    LLM_KV_GENERAL_ALIGNMENT,
+    LLM_KV_GENERAL_FILE_TYPE,
    LLM_KV_GENERAL_NAME,
    LLM_KV_GENERAL_AUTHOR,
    LLM_KV_GENERAL_VERSION,
@@ -100,6 +111,7 @@ enum llm_kv {
    LLM_KV_EXPERT_WEIGHTS_SCALE,
    LLM_KV_EXPERT_WEIGHTS_NORM,
    LLM_KV_EXPERT_GATING_FUNC,
+    LLM_KV_MOE_EVERY_N_LAYERS,
    LLM_KV_POOLING_TYPE,
    LLM_KV_LOGIT_SCALE,
    LLM_KV_DECODER_START_TOKEN_ID,
@@ -112,6 +124,7 @@ enum llm_kv {
    LLM_KV_RESIDUAL_SCALE,
    LLM_KV_EMBEDDING_SCALE,
    LLM_KV_TOKEN_SHIFT_COUNT,
+    LLM_KV_INTERLEAVE_MOE_LAYER_STEP,

    LLM_KV_ATTENTION_HEAD_COUNT,
    LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -126,9 +139,15 @@ enum llm_kv {
    LLM_KV_ATTENTION_CAUSAL,
    LLM_KV_ATTENTION_Q_LORA_RANK,
    LLM_KV_ATTENTION_KV_LORA_RANK,
+    LLM_KV_ATTENTION_DECAY_LORA_RANK,
+    LLM_KV_ATTENTION_ICLR_LORA_RANK,
+    LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK,
+    LLM_KV_ATTENTION_GATE_LORA_RANK,
    LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
    LLM_KV_ATTENTION_SLIDING_WINDOW,
    LLM_KV_ATTENTION_SCALE,
+    LLM_KV_ATTENTION_KEY_LENGTH_MLA,
+    LLM_KV_ATTENTION_VALUE_LENGTH_MLA,

    LLM_KV_ROPE_DIMENSION_COUNT,
    LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -242,6 +261,8 @@ enum llm_tensor {
    LLM_TENSOR_ATTN_Q_NORM,
    LLM_TENSOR_ATTN_K_NORM,
    LLM_TENSOR_LAYER_OUT_NORM,
+    LLM_TENSOR_POST_ATTN_NORM,
+    LLM_TENSOR_POST_MLP_NORM,
    LLM_TENSOR_SSM_IN,
    LLM_TENSOR_SSM_CONV1D,
    LLM_TENSOR_SSM_X,
@@ -249,8 +270,20 @@ enum llm_tensor {
    LLM_TENSOR_SSM_A,
    LLM_TENSOR_SSM_D,
    LLM_TENSOR_SSM_OUT,
+    LLM_TENSOR_TIME_MIX_W0,
    LLM_TENSOR_TIME_MIX_W1,
    LLM_TENSOR_TIME_MIX_W2,
+    LLM_TENSOR_TIME_MIX_A0,
+    LLM_TENSOR_TIME_MIX_A1,
+    LLM_TENSOR_TIME_MIX_A2,
+    LLM_TENSOR_TIME_MIX_V0,
+    LLM_TENSOR_TIME_MIX_V1,
+    LLM_TENSOR_TIME_MIX_V2,
+    LLM_TENSOR_TIME_MIX_G1,
+    LLM_TENSOR_TIME_MIX_G2,
+    LLM_TENSOR_TIME_MIX_K_K,
+    LLM_TENSOR_TIME_MIX_K_A,
+    LLM_TENSOR_TIME_MIX_R_K,
    LLM_TENSOR_TIME_MIX_LERP_X,
    LLM_TENSOR_TIME_MIX_LERP_W,
    LLM_TENSOR_TIME_MIX_LERP_K,
@@ -277,6 +310,8 @@ enum llm_tensor {
    LLM_TENSOR_ATTN_Q_B,
    LLM_TENSOR_ATTN_KV_A_MQA,
    LLM_TENSOR_ATTN_KV_B,
+    LLM_TENSOR_ATTN_K_B,
+    LLM_TENSOR_ATTN_V_B,
    LLM_TENSOR_ATTN_Q_A_NORM,
    LLM_TENSOR_ATTN_KV_A_NORM,
    LLM_TENSOR_ATTN_SUB_NORM,
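The new LLM_KV_ATTENTION_KEY_LENGTH_MLA / VALUE_LENGTH_MLA entries correspond to per-architecture GGUF metadata keys (see the "%s.attention.*_mla" patterns earlier). A hedged sketch of reading such a key with the public gguf API follows; the file path, default value, and the exact header name ("gguf.h" in recent ggml, "ggml.h" in older trees) are assumptions.

// Sketch: reading one of the MLA head-dimension keys from a GGUF file.
#include "gguf.h"

#include <cstdint>
#include <cstdio>

int main() {
    gguf_init_params params = { /* .no_alloc = */ true, /* .ctx = */ nullptr };
    gguf_context * ctx = gguf_init_from_file("model.gguf", params);
    if (!ctx) {
        return 1;
    }

    uint32_t key_length_mla = 0;
    const int64_t key_id = gguf_find_key(ctx, "deepseek2.attention.key_length_mla");
    if (key_id >= 0) {
        key_length_mla = gguf_get_val_u32(ctx, key_id);
    }

    std::printf("key_length_mla = %u\n", key_length_mla);

    gguf_free(ctx);
    return 0;
}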
@@ -42,9 +42,9 @@ struct llama_sbatch {
    bool logits_all; // TODO: remove once lctx.logits_all is removed too

    // sorted indices into the batch
-    std::vector<size_t> ids;
+    std::vector<int64_t> ids;
    // batch indices of the output
-    std::vector<size_t> out_ids;
+    std::vector<int64_t> out_ids;
    std::vector<llama_sbatch_seq> seq;

    const llama_batch * batch = nullptr;
@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
#if __cplusplus >= 202000L
|
#if __cplusplus >= 202000L
|
||||||
#define LU8(x) (const char*)(u8##x)
|
#define LU8(x) (const char*)(u8##x)
|
||||||
@ -49,8 +50,8 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
|||||||
{ "deepseek3", LLM_CHAT_TEMPLATE_DEEPSEEK_3 },
|
{ "deepseek3", LLM_CHAT_TEMPLATE_DEEPSEEK_3 },
|
||||||
{ "command-r", LLM_CHAT_TEMPLATE_COMMAND_R },
|
{ "command-r", LLM_CHAT_TEMPLATE_COMMAND_R },
|
||||||
{ "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 },
|
{ "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 },
|
||||||
{ "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 },
|
{ "chatglm3", LLM_CHAT_TEMPLATE_CHATGLM_3 },
|
||||||
{ "chatglm4", LLM_CHAT_TEMPLATE_CHATGML_4 },
|
{ "chatglm4", LLM_CHAT_TEMPLATE_CHATGLM_4 },
|
||||||
{ "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE },
|
{ "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE },
|
||||||
{ "minicpm", LLM_CHAT_TEMPLATE_MINICPM },
|
{ "minicpm", LLM_CHAT_TEMPLATE_MINICPM },
|
||||||
{ "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
|
{ "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
|
||||||
@ -58,6 +59,10 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
|||||||
{ "granite", LLM_CHAT_TEMPLATE_GRANITE },
|
{ "granite", LLM_CHAT_TEMPLATE_GRANITE },
|
||||||
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
|
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
|
||||||
{ "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
|
{ "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
|
||||||
|
{ "yandex", LLM_CHAT_TEMPLATE_YANDEX },
|
||||||
|
{ "bailing", LLM_CHAT_TEMPLATE_BAILING },
|
||||||
|
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
|
||||||
|
{ "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
|
||||||
};
|
};
|
||||||
|
|
||||||
llm_chat_template llm_chat_template_from_str(const std::string & name) {
|
llm_chat_template llm_chat_template_from_str(const std::string & name) {
|
||||||
@ -77,7 +82,9 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
|||||||
if (tmpl_contains("<|im_start|>")) {
|
if (tmpl_contains("<|im_start|>")) {
|
||||||
return tmpl_contains("<|im_sep|>")
|
return tmpl_contains("<|im_sep|>")
|
||||||
? LLM_CHAT_TEMPLATE_PHI_4
|
? LLM_CHAT_TEMPLATE_PHI_4
|
||||||
: LLM_CHAT_TEMPLATE_CHATML;
|
: tmpl_contains("<end_of_utterance>")
|
||||||
|
? LLM_CHAT_TEMPLATE_SMOLVLM // SmolVLM uses <|im_start|> as BOS, but it is NOT chatml
|
||||||
|
: LLM_CHAT_TEMPLATE_CHATML;
|
||||||
} else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) {
|
} else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) {
|
||||||
if (tmpl_contains("[SYSTEM_PROMPT]")) {
|
if (tmpl_contains("[SYSTEM_PROMPT]")) {
|
||||||
return LLM_CHAT_TEMPLATE_MISTRAL_V7;
|
return LLM_CHAT_TEMPLATE_MISTRAL_V7;
|
||||||
@ -115,8 +122,12 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
|||||||
}
|
}
|
||||||
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) {
|
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) {
|
||||||
return LLM_CHAT_TEMPLATE_PHI_3;
|
return LLM_CHAT_TEMPLATE_PHI_3;
|
||||||
|
} else if (tmpl_contains("[gMASK]<sop>")) {
|
||||||
|
return LLM_CHAT_TEMPLATE_CHATGLM_4;
|
||||||
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
|
} else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
|
||||||
return tmpl_contains("</s>") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE;
|
return tmpl_contains("</s>") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE;
|
||||||
|
} else if (tmpl_contains("<|{{ item['role'] }}|>") && tmpl_contains("<|begin_of_image|>")) {
|
||||||
|
return LLM_CHAT_TEMPLATE_GLMEDGE;
|
||||||
} else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) {
|
} else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) {
|
||||||
return LLM_CHAT_TEMPLATE_ZEPHYR;
|
return LLM_CHAT_TEMPLATE_ZEPHYR;
|
||||||
} else if (tmpl_contains("bos_token + message['role']")) {
|
} else if (tmpl_contains("bos_token + message['role']")) {
|
||||||
@ -145,9 +156,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
|||||||
return LLM_CHAT_TEMPLATE_LLAMA_3;
|
return LLM_CHAT_TEMPLATE_LLAMA_3;
|
||||||
} else if (tmpl_contains("[gMASK]sop")) {
|
} else if (tmpl_contains("[gMASK]sop")) {
|
||||||
// chatglm3-6b
|
// chatglm3-6b
|
||||||
return LLM_CHAT_TEMPLATE_CHATGML_3;
|
return LLM_CHAT_TEMPLATE_CHATGLM_3;
|
||||||
} else if (tmpl_contains("[gMASK]<sop>")) {
|
|
||||||
return LLM_CHAT_TEMPLATE_CHATGML_4;
|
|
||||||
} else if (tmpl_contains(LU8("<用户>"))) {
|
} else if (tmpl_contains(LU8("<用户>"))) {
|
||||||
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
|
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
|
||||||
return LLM_CHAT_TEMPLATE_MINICPM;
|
return LLM_CHAT_TEMPLATE_MINICPM;
|
||||||
@ -167,6 +176,12 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
|||||||
return LLM_CHAT_TEMPLATE_GIGACHAT;
|
return LLM_CHAT_TEMPLATE_GIGACHAT;
|
||||||
} else if (tmpl_contains("<|role_start|>")) {
|
} else if (tmpl_contains("<|role_start|>")) {
|
||||||
return LLM_CHAT_TEMPLATE_MEGREZ;
|
return LLM_CHAT_TEMPLATE_MEGREZ;
|
||||||
|
} else if (tmpl_contains(" Ассистент:")) {
|
||||||
|
return LLM_CHAT_TEMPLATE_YANDEX;
|
||||||
|
} else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("'HUMAN'")) {
|
||||||
|
return LLM_CHAT_TEMPLATE_BAILING;
|
||||||
|
} else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
|
||||||
|
return LLM_CHAT_TEMPLATE_LLAMA4;
|
||||||
}
|
}
|
||||||
return LLM_CHAT_TEMPLATE_UNKNOWN;
|
return LLM_CHAT_TEMPLATE_UNKNOWN;
|
||||||
}
|
}
|
||||||
@ -422,7 +437,7 @@ int32_t llm_chat_apply_template(
|
|||||||
if (add_ass) {
|
if (add_ass) {
|
||||||
ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
|
ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
|
||||||
}
|
}
|
||||||
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_3) {
|
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_3) {
|
||||||
// chatglm3-6b
|
// chatglm3-6b
|
||||||
ss << "[gMASK]" << "sop";
|
ss << "[gMASK]" << "sop";
|
||||||
for (auto message : chat) {
|
for (auto message : chat) {
|
||||||
@ -432,7 +447,7 @@ int32_t llm_chat_apply_template(
|
|||||||
if (add_ass) {
|
if (add_ass) {
|
||||||
ss << "<|assistant|>";
|
ss << "<|assistant|>";
|
||||||
}
|
}
|
||||||
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_4) {
|
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_4 || tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) {
|
||||||
ss << "[gMASK]" << "<sop>";
|
ss << "[gMASK]" << "<sop>";
|
||||||
for (auto message : chat) {
|
for (auto message : chat) {
|
||||||
std::string role(message->role);
|
std::string role(message->role);
|
||||||
@ -441,14 +456,6 @@ int32_t llm_chat_apply_template(
|
|||||||
if (add_ass) {
|
if (add_ass) {
|
||||||
ss << "<|assistant|>";
|
ss << "<|assistant|>";
|
||||||
}
|
}
|
||||||
} else if (tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) {
|
|
||||||
for (auto message : chat) {
|
|
||||||
std::string role(message->role);
|
|
||||||
ss << "<|" << role << "|>" << "\n" << message->content;
|
|
||||||
}
|
|
||||||
if (add_ass) {
|
|
||||||
ss << "<|assistant|>";
|
|
||||||
}
|
|
||||||
} else if (tmpl == LLM_CHAT_TEMPLATE_MINICPM) {
|
} else if (tmpl == LLM_CHAT_TEMPLATE_MINICPM) {
|
||||||
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
|
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
|
||||||
for (auto message : chat) {
|
for (auto message : chat) {
|
||||||
@ -566,6 +573,66 @@ int32_t llm_chat_apply_template(
|
|||||||
if (add_ass) {
|
if (add_ass) {
|
||||||
ss << "<|role_start|>assistant<|role_end|>";
|
ss << "<|role_start|>assistant<|role_end|>";
|
||||||
}
|
}
|
||||||
|
+    } else if (tmpl == LLM_CHAT_TEMPLATE_YANDEX) {
+        // Yandex template ("\n\n" is defined as EOT token)
+
+        ss << "<s>";
+
+        for (size_t i = 0; i < chat.size(); i++) {
+            std::string role(chat[i]->role);
+            if (role == "user") {
+                ss << " Пользователь: " << chat[i]->content << "\n\n";
+            } else if (role == "assistant") {
+                ss << " Ассистент: " << chat[i]->content << "\n\n";
+            }
+        }
+
+        // Add generation prompt if needed
+        if (add_ass) {
+            ss << " Ассистент:[SEP]";
+        }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_BAILING) {
+        // Bailing (Ling) template
+        for (auto message : chat) {
+            std::string role(message->role);
+
+            if (role == "user") {
+                role = "HUMAN";
+            } else {
+                std::transform(role.begin(), role.end(), role.begin(), ::toupper);
+            }
+
+            ss << "<role>" << role << "</role>" << message->content;
+        }
+
+        if (add_ass) {
+            ss << "<role>ASSISTANT</role>";
+        }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA4) {
+        // Llama 4
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<|header_start|>" << role << "<|header_end|>\n\n" << trim(message->content) << "<|eot|>";
+        }
+        if (add_ass) {
+            ss << "<|header_start|>assistant<|header_end|>\n\n";
+        }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_SMOLVLM) {
+        // SmolVLM
+        ss << "<|im_start|>"; // uses <|im_start|> as BOS, but the actual content is NOT chatml
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << message->content << "\n\n";
+            } else if (role == "user") {
+                ss << "User: " << message->content << "<end_of_utterance>\n";
+            } else {
+                ss << "Assistant: " << message->content << "<end_of_utterance>\n";
+            }
+        }
+        if (add_ass) {
+            ss << "Assistant:";
+        }
     } else {
         // template not supported
         return -1;
     }
@@ -584,4 +651,3 @@ int32_t llama_chat_builtin_templates(const char ** output, size_t len) {
     }
     return (int32_t) LLM_CHAT_TEMPLATES.size();
 }
-
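For orientation (not part of the diff): a minimal standalone sketch of the string the new LLM_CHAT_TEMPLATE_LLAMA4 branch above builds for a single user message with add_ass enabled; the driver program is illustrative only.

// Illustrative only: expected output of the Llama 4 branch above for
// chat = [{role: "user", content: "Hello"}] and add_ass == true.
#include <iostream>
#include <string>

int main() {
    const std::string prompt =
        "<|header_start|>user<|header_end|>\n\n"
        "Hello<|eot|>"
        "<|header_start|>assistant<|header_end|>\n\n"; // trailing assistant header acts as the generation prompt
    std::cout << prompt;
    return 0;
}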
@@ -29,8 +29,8 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_DEEPSEEK_3,
     LLM_CHAT_TEMPLATE_COMMAND_R,
     LLM_CHAT_TEMPLATE_LLAMA_3,
-    LLM_CHAT_TEMPLATE_CHATGML_3,
-    LLM_CHAT_TEMPLATE_CHATGML_4,
+    LLM_CHAT_TEMPLATE_CHATGLM_3,
+    LLM_CHAT_TEMPLATE_CHATGLM_4,
     LLM_CHAT_TEMPLATE_GLMEDGE,
     LLM_CHAT_TEMPLATE_MINICPM,
     LLM_CHAT_TEMPLATE_EXAONE_3,
@@ -38,6 +38,10 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_GRANITE,
     LLM_CHAT_TEMPLATE_GIGACHAT,
     LLM_CHAT_TEMPLATE_MEGREZ,
+    LLM_CHAT_TEMPLATE_YANDEX,
+    LLM_CHAT_TEMPLATE_BAILING,
+    LLM_CHAT_TEMPLATE_LLAMA4,
+    LLM_CHAT_TEMPLATE_SMOLVLM,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
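The new template names are also reachable through the llama_chat_builtin_templates() helper referenced above; a hedged usage sketch follows (the fixed buffer of 64 entries is an arbitrary assumption).

// Sketch: enumerate the built-in chat template names exposed by llama.h.
#include "llama.h"
#include <cstdio>

int main() {
    const char * names[64];
    int32_t n = llama_chat_builtin_templates(names, 64);
    for (int32_t i = 0; i < n && i < 64; i++) {
        printf("%s\n", names[i]); // should now include the newly added templates
    }
    return 0;
}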
File diff suppressed because it is too large

@@ -3,66 +3,212 @@
|
|||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
#include "llama-batch.h"
|
#include "llama-batch.h"
|
||||||
#include "llama-cparams.h"
|
#include "llama-cparams.h"
|
||||||
#include "llama-model.h"
|
#include "llama-graph.h"
|
||||||
#include "llama-kv-cache.h"
|
|
||||||
#include "llama-adapter.h"
|
#include "llama-adapter.h"
|
||||||
|
|
||||||
#include "ggml-cpp.h"
|
#include "ggml-cpp.h"
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <unordered_map>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <set>
|
|
||||||
|
struct llama_model;
|
||||||
|
struct llama_kv_cache;
|
||||||
|
|
||||||
|
class llama_io_read_i;
|
||||||
|
class llama_io_write_i;
|
||||||
|
|
||||||
struct llama_context {
|
struct llama_context {
|
||||||
llama_context(const llama_model & model)
|
// init scheduler and compute buffers, reserve worst-case graphs
|
||||||
: model(model)
|
llama_context(
|
||||||
, t_start_us(model.t_start_us)
|
const llama_model & model,
|
||||||
, t_load_us(model.t_load_us) {}
|
llama_context_params params);
|
||||||
|
|
||||||
const struct llama_model & model;
|
~llama_context();
|
||||||
|
|
||||||
struct llama_cparams cparams;
|
void synchronize();
|
||||||
struct llama_sbatch sbatch; // TODO: revisit if needed
|
|
||||||
struct llama_kv_cache kv_self;
|
|
||||||
struct llama_adapter_cvec cvec;
|
|
||||||
|
|
||||||
std::unordered_map<struct llama_adapter_lora *, float> lora;
|
const llama_model & get_model() const;
|
||||||
|
|
||||||
std::vector<ggml_backend_ptr> backends;
|
uint32_t n_ctx() const;
|
||||||
std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
|
uint32_t n_ctx_per_seq() const;
|
||||||
|
uint32_t n_batch() const;
|
||||||
|
uint32_t n_ubatch() const;
|
||||||
|
uint32_t n_seq_max() const;
|
||||||
|
|
||||||
ggml_backend_t backend_cpu = nullptr;
|
uint32_t n_threads() const;
|
||||||
|
uint32_t n_threads_batch() const;
|
||||||
|
|
||||||
ggml_threadpool_t threadpool = nullptr;
|
llama_kv_cache * get_kv_self();
|
||||||
ggml_threadpool_t threadpool_batch = nullptr;
|
const llama_kv_cache * get_kv_self() const;
|
||||||
|
|
||||||
bool has_evaluated_once = false;
|
void kv_self_update();
|
||||||
|
|
||||||
mutable int64_t t_start_us;
|
enum llama_pooling_type pooling_type() const;
|
||||||
mutable int64_t t_load_us;
|
|
||||||
mutable int64_t t_p_eval_us = 0;
|
|
||||||
mutable int64_t t_eval_us = 0;
|
|
||||||
|
|
||||||
mutable int64_t t_compute_start_us = 0;
|
float * get_logits();
|
||||||
mutable int64_t n_queued_tokens = 0;
|
float * get_logits_ith(int32_t i);
|
||||||
|
|
||||||
mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
|
float * get_embeddings();
|
||||||
mutable int32_t n_eval = 0; // number of eval calls
|
float * get_embeddings_ith(int32_t i);
|
||||||
|
float * get_embeddings_seq(llama_seq_id seq_id);
|
||||||
|
|
||||||
// host buffer for the model output (logits and embeddings)
|
void attach_threadpool(
|
||||||
ggml_backend_buffer_ptr buf_output;
|
ggml_threadpool_t threadpool,
|
||||||
|
ggml_threadpool_t threadpool_batch);
|
||||||
|
|
||||||
|
void detach_threadpool();
|
||||||
|
|
||||||
|
void set_n_threads(int32_t n_threads, int32_t n_threads_batch);
|
||||||
|
|
||||||
|
void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);
|
||||||
|
|
||||||
|
void set_embeddings (bool value);
|
||||||
|
void set_causal_attn(bool value);
|
||||||
|
void set_warmup(bool value);
|
||||||
|
|
||||||
|
void set_adapter_lora(
|
||||||
|
llama_adapter_lora * adapter,
|
||||||
|
float scale);
|
||||||
|
|
||||||
|
bool rm_adapter_lora(
|
||||||
|
llama_adapter_lora * adapter);
|
||||||
|
|
||||||
|
void clear_adapter_lora();
|
||||||
|
|
||||||
|
bool apply_adapter_cvec(
|
||||||
|
const float * data,
|
||||||
|
size_t len,
|
||||||
|
int32_t n_embd,
|
||||||
|
int32_t il_start,
|
||||||
|
int32_t il_end);
|
||||||
|
|
||||||
|
int encode(llama_batch & inp_batch);
|
||||||
|
int decode(llama_batch & inp_batch);
|
||||||
|
|
||||||
|
//
|
||||||
|
// state save/load
|
||||||
|
//
|
||||||
|
|
||||||
|
size_t state_get_size();
|
||||||
|
size_t state_get_data( uint8_t * dst, size_t size);
|
||||||
|
size_t state_set_data(const uint8_t * src, size_t size);
|
||||||
|
|
||||||
|
size_t state_seq_get_size(llama_seq_id seq_id);
|
||||||
|
size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size);
|
||||||
|
size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);
|
||||||
|
|
||||||
|
bool state_load_file(
|
||||||
|
const char * filepath,
|
||||||
|
llama_token * tokens_out,
|
||||||
|
size_t n_token_capacity,
|
||||||
|
size_t * n_token_count_out);
|
||||||
|
|
||||||
|
bool state_save_file(
|
||||||
|
const char * filepath,
|
||||||
|
const llama_token * tokens,
|
||||||
|
size_t n_token_count);
|
||||||
|
|
||||||
|
size_t state_seq_load_file(
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
const char * filepath,
|
||||||
|
llama_token * tokens_out,
|
||||||
|
size_t n_token_capacity,
|
||||||
|
size_t * n_token_count_out);
|
||||||
|
|
||||||
|
size_t state_seq_save_file(
|
||||||
|
llama_seq_id seq_id,
|
||||||
|
const char * filepath,
|
||||||
|
const llama_token * tokens,
|
||||||
|
size_t n_token_count);
|
||||||
|
|
||||||
|
//
|
||||||
|
// perf
|
||||||
|
//
|
||||||
|
|
||||||
|
llama_perf_context_data perf_get_data() const;
|
||||||
|
void perf_reset();
|
||||||
|
|
||||||
|
private:
|
||||||
|
//
|
||||||
|
// output
|
||||||
|
//
|
||||||
|
|
||||||
|
// Make sure enough space is available for outputs.
|
||||||
|
// Returns max number of outputs for which space was reserved.
|
||||||
|
int32_t output_reserve(int32_t n_outputs);
|
||||||
|
|
||||||
|
// make the outputs have the same order they had in the user-provided batch
|
||||||
|
// TODO: maybe remove this
|
||||||
|
void output_reorder();
|
||||||
|
|
||||||
|
//
|
||||||
|
// graph
|
||||||
|
//
|
||||||
|
|
||||||
|
int32_t graph_max_nodes() const;
|
||||||
|
|
||||||
|
// zero-out inputs and create the ctx_compute for the compute graph
|
||||||
|
ggml_cgraph * graph_init();
|
||||||
|
|
||||||
|
llm_graph_result_ptr graph_build(
|
||||||
|
ggml_context * ctx,
|
||||||
|
ggml_cgraph * gf,
|
||||||
|
const llama_ubatch & ubatch,
|
||||||
|
llm_graph_type gtype);
|
||||||
|
|
||||||
|
// returns the result of ggml_backend_sched_graph_compute_async execution
|
||||||
|
ggml_status graph_compute(
|
||||||
|
ggml_cgraph * gf,
|
||||||
|
bool batched);
|
||||||
|
|
||||||
|
llm_graph_cb graph_get_cb() const;
|
||||||
|
|
||||||
|
// used by kv_self_update()
|
||||||
|
ggml_tensor * build_rope_shift(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
ggml_tensor * cur,
|
||||||
|
ggml_tensor * shift,
|
||||||
|
ggml_tensor * factors,
|
||||||
|
float freq_base,
|
||||||
|
float freq_scale) const;
|
||||||
|
|
||||||
|
llm_graph_result_ptr build_kv_self_shift(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
ggml_cgraph * gf) const;
|
||||||
|
|
||||||
|
llm_graph_result_ptr build_kv_self_defrag(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
ggml_cgraph * gf) const;
|
||||||
|
|
||||||
|
// TODO: read/write lora adapters and cvec
|
||||||
|
size_t state_write_data(llama_io_write_i & io);
|
||||||
|
size_t state_read_data (llama_io_read_i & io);
|
||||||
|
|
||||||
|
size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id);
|
||||||
|
size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id);
|
||||||
|
|
||||||
|
//
|
||||||
|
// members
|
||||||
|
//
|
||||||
|
|
||||||
|
const llama_model & model;
|
||||||
|
|
||||||
|
llama_cparams cparams;
|
||||||
|
llama_adapter_cvec cvec;
|
||||||
|
llama_adapter_loras loras;
|
||||||
|
llama_sbatch sbatch;
|
||||||
|
|
||||||
|
llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
|
||||||
|
|
||||||
|
std::unique_ptr<llama_kv_cache_unified> kv_self;
|
||||||
|
|
||||||
|
// TODO: remove
|
||||||
|
bool logits_all = false;
|
||||||
|
|
||||||
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
||||||
size_t logits_size = 0; // capacity (of floats) for logits
|
size_t logits_size = 0; // capacity (of floats) for logits
|
||||||
float * logits = nullptr;
|
float * logits = nullptr;
|
||||||
|
|
||||||
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
|
|
||||||
size_t output_size = 0; // capacity (of tokens positions) for the output buffers
|
|
||||||
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
|
|
||||||
|
|
||||||
bool logits_all = false;
|
|
||||||
|
|
||||||
// embeddings output (2-dimensional array: [n_outputs][n_embd])
|
// embeddings output (2-dimensional array: [n_outputs][n_embd])
|
||||||
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
||||||
size_t embd_size = 0; // capacity (of floats) for embeddings
|
size_t embd_size = 0; // capacity (of floats) for embeddings
|
||||||
@ -72,57 +218,47 @@ struct llama_context {
|
|||||||
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
|
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
|
||||||
std::map<llama_seq_id, std::vector<float>> embd_seq;
|
std::map<llama_seq_id, std::vector<float>> embd_seq;
|
||||||
|
|
||||||
// whether we are computing encoder output or decoder output
|
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
|
||||||
bool is_encoding = false;
|
int32_t n_outputs_max = 0; // capacity (of tokens positions) for the output buffers
|
||||||
|
|
||||||
// TODO: find a better way to accommodate mutli-dimension position encoding methods
|
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
|
||||||
// number of position id each token get, 1 for each token in most cases.
|
|
||||||
// when using m-rope, it will be 3 position ids per token to representing 3 dimension coordinate.
|
|
||||||
int n_pos_per_token = 1;
|
|
||||||
|
|
||||||
// output of the encoder part of the encoder-decoder models
|
|
||||||
std::vector<float> embd_enc;
|
|
||||||
std::vector<std::set<llama_seq_id>> seq_ids_enc;
|
|
||||||
|
|
||||||
// memory buffers used to evaluate the model
|
|
||||||
std::vector<uint8_t> buf_compute_meta;
|
|
||||||
ggml_backend_sched_ptr sched;
|
ggml_backend_sched_ptr sched;
|
||||||
|
|
||||||
|
ggml_backend_t backend_cpu = nullptr;
|
||||||
|
std::vector<ggml_backend_ptr> backends;
|
||||||
|
|
||||||
|
ggml_context_ptr ctx_compute;
|
||||||
|
|
||||||
|
ggml_threadpool_t threadpool = nullptr;
|
||||||
|
ggml_threadpool_t threadpool_batch = nullptr;
|
||||||
|
|
||||||
ggml_abort_callback abort_callback = nullptr;
|
ggml_abort_callback abort_callback = nullptr;
|
||||||
void * abort_callback_data = nullptr;
|
void * abort_callback_data = nullptr;
|
||||||
|
|
||||||
// input tensors
|
std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
|
||||||
struct ggml_tensor * inp_tokens; // I32 [n_batch]
|
|
||||||
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
|
// buffer types used for the compute buffer of each backend
|
||||||
struct ggml_tensor * inp_pos; // I32 [n_batch]
|
std::vector<ggml_backend_t> backend_ptrs;
|
||||||
struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
|
std::vector<ggml_backend_buffer_type_t> backend_buft;
|
||||||
struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
|
|
||||||
struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch]
|
// memory buffers used to evaluate the model
|
||||||
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
|
std::vector<uint8_t> buf_compute_meta;
|
||||||
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
|
|
||||||
struct ggml_tensor * inp_cls; // I32 [n_batch]
|
// host buffer for the model output (logits and embeddings)
|
||||||
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
|
ggml_backend_buffer_ptr buf_output;
|
||||||
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
|
|
||||||
struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
|
bool has_evaluated_once = false;
|
||||||
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
|
|
||||||
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
|
// perf
|
||||||
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
|
mutable int64_t t_start_us = 0;
|
||||||
|
mutable int64_t t_load_us = 0;
|
||||||
|
mutable int64_t t_p_eval_us = 0;
|
||||||
|
mutable int64_t t_eval_us = 0;
|
||||||
|
|
||||||
|
mutable int64_t t_compute_start_us = 0;
|
||||||
|
mutable int64_t n_queued_tokens = 0;
|
||||||
|
|
||||||
|
mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
|
||||||
|
mutable int32_t n_eval = 0; // number of eval calls
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO: make these methods of llama_context
|
|
||||||
void llama_set_k_shift(struct llama_context & lctx);
|
|
||||||
|
|
||||||
void llama_set_s_copy(struct llama_context & lctx);
|
|
||||||
|
|
||||||
void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch);
|
|
||||||
|
|
||||||
// Make sure enough space is available for outputs.
|
|
||||||
// Returns max number of outputs for which space was reserved.
|
|
||||||
size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs);
|
|
||||||
|
|
||||||
// make the outputs have the same order they had in the user-provided batch
|
|
||||||
void llama_output_reorder(struct llama_context & ctx);
|
|
||||||
|
|
||||||
// For internal test use
|
|
||||||
// TODO: remove
|
|
||||||
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(struct llama_context * ctx);
|
|
||||||
|
@@ -29,6 +29,7 @@ struct llama_cparams {
     bool offload_kqv;
     bool flash_attn;
     bool no_perf;
+    bool warmup;

     enum llama_pooling_type pooling_type;

@@ -345,194 +345,194 @@ const char * llama_grammar_parser::parse_sequence(
     size_t last_sym_start = rule.size();
     const char * pos = src;

     auto handle_repetitions = [&](int min_times, int max_times) {

         if (last_sym_start == rule.size()) {
             throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
         }

         // apply transformation to previous symbol (last_sym_start to end) according to
         // the following rewrite rules:
         // S{m,n} --> S S S (m times) S'(n-m)
         //            S'(x)   ::= S S'(x-1) |
         //            (... n-m definitions of these S' rules ...)
         //            S'(1)   ::= S |
         // S{m,}  --> S S S (m times) S'
         //            S'     ::= S S' |
         // S*     --> S{0,}
         //        --> S'     ::= S S' |
         // S+     --> S{1,}
         //        --> S S'
         //               S' ::= S S' |
         // S?     --> S{0,1}
         //        --> S'
         //               S' ::= S |

         llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
         if (min_times == 0) {
             rule.resize(last_sym_start);
         } else {
             // Repeat the previous elements (min_times - 1) times
             for (int i = 1; i < min_times; i++) {
                 rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
             }
         }
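         // Worked illustration (added for clarity; not in the original source): for "S{2,4}",
         // handle_repetitions(2, 4) keeps the original S and inserts one extra copy
         // (min_times - 1 = 1); the synthesized-rule loop that follows then creates
         // n - m = 2 helper rules
         //   S'(1) ::= S |
         //   S'(2) ::= S S'(1) |
         // and appends a reference to S'(2), so between 2 and 4 occurrences of S are accepted.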
|
|
||||||
uint32_t last_rec_rule_id = 0;
|
|
||||||
auto n_opt = max_times < 0 ? 1 : max_times - min_times;
|
|
||||||
|
|
||||||
llama_grammar_rule rec_rule(prev_rule);
|
|
||||||
for (int i = 0; i < n_opt; i++) {
|
|
||||||
rec_rule.resize(prev_rule.size());
|
|
||||||
uint32_t rec_rule_id = generate_symbol_id( rule_name);
|
|
||||||
if (i > 0 || max_times < 0) {
|
|
||||||
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
|
|
||||||
}
|
|
||||||
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
|
|
||||||
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
|
|
||||||
add_rule( rec_rule_id, rec_rule);
|
|
||||||
last_rec_rule_id = rec_rule_id;
|
|
||||||
}
|
|
||||||
if (n_opt > 0) {
|
|
||||||
rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
while (*pos) {
|
|
||||||
if (*pos == '"') { // literal string
|
|
||||||
pos++;
|
|
||||||
last_sym_start = rule.size();
|
|
||||||
while (*pos != '"') {
|
|
||||||
if (!*pos) {
|
|
||||||
throw std::runtime_error("unexpected end of input");
|
|
||||||
}
|
|
||||||
auto char_pair = parse_char(pos);
|
|
||||||
pos = char_pair.second;
|
|
||||||
rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
|
|
||||||
}
|
|
||||||
pos = parse_space(pos + 1, is_nested);
|
|
||||||
} else if (*pos == '[') { // char range(s)
|
|
||||||
pos++;
|
|
||||||
enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
|
|
||||||
if (*pos == '^') {
|
|
||||||
pos++;
|
|
||||||
start_type = LLAMA_GRETYPE_CHAR_NOT;
|
|
||||||
}
|
|
||||||
last_sym_start = rule.size();
|
|
||||||
while (*pos != ']') {
|
|
||||||
if (!*pos) {
|
|
||||||
throw std::runtime_error("unexpected end of input");
|
|
||||||
}
|
|
||||||
auto char_pair = parse_char(pos);
|
|
||||||
pos = char_pair.second;
|
|
||||||
enum llama_gretype type = last_sym_start < rule.size()
|
|
||||||
? LLAMA_GRETYPE_CHAR_ALT
|
|
||||||
: start_type;
|
|
||||||
|
|
||||||
rule.push_back({type, char_pair.first});
|
|
||||||
if (pos[0] == '-' && pos[1] != ']') {
|
|
||||||
if (!pos[1]) {
|
|
||||||
throw std::runtime_error("unexpected end of input");
|
|
||||||
}
|
|
||||||
auto endchar_pair = parse_char(pos + 1);
|
|
||||||
pos = endchar_pair.second;
|
|
||||||
rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
pos = parse_space(pos + 1, is_nested);
|
|
||||||
} else if (is_word_char(*pos)) { // rule reference
|
|
||||||
const char * name_end = parse_name(pos);
|
|
||||||
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
|
|
||||||
pos = parse_space(name_end, is_nested);
|
|
||||||
last_sym_start = rule.size();
|
|
||||||
rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
|
|
||||||
} else if (*pos == '(') { // grouping
|
|
||||||
// parse nested alternates into synthesized rule
|
|
||||||
pos = parse_space(pos + 1, true);
|
|
||||||
uint32_t sub_rule_id = generate_symbol_id(rule_name);
|
|
||||||
pos = parse_alternates(pos, rule_name, sub_rule_id, true);
|
|
||||||
last_sym_start = rule.size();
|
|
||||||
// output reference to synthesized rule
|
|
||||||
rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
|
|
||||||
if (*pos != ')') {
|
|
||||||
throw std::runtime_error(std::string("expecting ')' at ") + pos);
|
|
||||||
}
|
|
||||||
pos = parse_space(pos + 1, is_nested);
|
|
||||||
} else if (*pos == '.') { // any char
|
|
||||||
last_sym_start = rule.size();
|
|
||||||
rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
|
|
||||||
pos = parse_space(pos + 1, is_nested);
|
|
||||||
} else if (*pos == '*') {
|
|
||||||
pos = parse_space(pos + 1, is_nested);
|
|
||||||
handle_repetitions(0, -1);
|
|
||||||
} else if (*pos == '+') {
|
|
||||||
pos = parse_space(pos + 1, is_nested);
|
|
||||||
handle_repetitions(1, -1);
|
|
||||||
} else if (*pos == '?') {
|
|
||||||
pos = parse_space(pos + 1, is_nested);
|
|
||||||
handle_repetitions(0, 1);
|
|
||||||
} else if (*pos == '{') {
|
|
||||||
pos = parse_space(pos + 1, is_nested);
|
|
||||||
|
|
||||||
if (!is_digit_char(*pos)) {
|
|
||||||
throw std::runtime_error(std::string("expecting an int at ") + pos);
|
|
||||||
}
|
|
||||||
const char * int_end = parse_int(pos);
|
|
||||||
int min_times = std::stoul(std::string(pos, int_end - pos));
|
|
||||||
pos = parse_space(int_end, is_nested);
|
|
||||||
|
|
||||||
int max_times = -1;
|
|
||||||
|
|
||||||
if (*pos == '}') {
|
|
||||||
max_times = min_times;
|
|
||||||
pos = parse_space(pos + 1, is_nested);
|
|
||||||
} else if (*pos == ',') {
|
|
||||||
pos = parse_space(pos + 1, is_nested);
|
|
||||||
|
|
||||||
if (is_digit_char(*pos)) {
|
|
||||||
const char * int_end = parse_int(pos);
|
|
||||||
max_times = std::stoul(std::string(pos, int_end - pos));
|
|
||||||
pos = parse_space(int_end, is_nested);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (*pos != '}') {
|
|
||||||
throw std::runtime_error(std::string("expecting '}' at ") + pos);
|
|
||||||
}
|
|
||||||
pos = parse_space(pos + 1, is_nested);
|
|
||||||
} else {
|
|
||||||
throw std::runtime_error(std::string("expecting ',' at ") + pos);
|
|
||||||
}
|
|
||||||
handle_repetitions(min_times, max_times);
|
|
||||||
} else {
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return pos;
|
|
||||||
|
uint32_t last_rec_rule_id = 0;
|
||||||
|
auto n_opt = max_times < 0 ? 1 : max_times - min_times;
|
||||||
|
|
||||||
|
llama_grammar_rule rec_rule(prev_rule);
|
||||||
|
for (int i = 0; i < n_opt; i++) {
|
||||||
|
rec_rule.resize(prev_rule.size());
|
||||||
|
uint32_t rec_rule_id = generate_symbol_id( rule_name);
|
||||||
|
if (i > 0 || max_times < 0) {
|
||||||
|
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
|
||||||
|
}
|
||||||
|
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
|
||||||
|
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
|
||||||
|
add_rule( rec_rule_id, rec_rule);
|
||||||
|
last_rec_rule_id = rec_rule_id;
|
||||||
|
}
|
||||||
|
if (n_opt > 0) {
|
||||||
|
rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
while (*pos) {
|
||||||
|
if (*pos == '"') { // literal string
|
||||||
|
pos++;
|
||||||
|
last_sym_start = rule.size();
|
||||||
|
while (*pos != '"') {
|
||||||
|
if (!*pos) {
|
||||||
|
throw std::runtime_error("unexpected end of input");
|
||||||
|
}
|
||||||
|
auto char_pair = parse_char(pos);
|
||||||
|
pos = char_pair.second;
|
||||||
|
rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
|
||||||
|
}
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
} else if (*pos == '[') { // char range(s)
|
||||||
|
pos++;
|
||||||
|
enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
|
||||||
|
if (*pos == '^') {
|
||||||
|
pos++;
|
||||||
|
start_type = LLAMA_GRETYPE_CHAR_NOT;
|
||||||
|
}
|
||||||
|
last_sym_start = rule.size();
|
||||||
|
while (*pos != ']') {
|
||||||
|
if (!*pos) {
|
||||||
|
throw std::runtime_error("unexpected end of input");
|
||||||
|
}
|
||||||
|
auto char_pair = parse_char(pos);
|
||||||
|
pos = char_pair.second;
|
||||||
|
enum llama_gretype type = last_sym_start < rule.size()
|
||||||
|
? LLAMA_GRETYPE_CHAR_ALT
|
||||||
|
: start_type;
|
||||||
|
|
||||||
|
rule.push_back({type, char_pair.first});
|
||||||
|
if (pos[0] == '-' && pos[1] != ']') {
|
||||||
|
if (!pos[1]) {
|
||||||
|
throw std::runtime_error("unexpected end of input");
|
||||||
|
}
|
||||||
|
auto endchar_pair = parse_char(pos + 1);
|
||||||
|
pos = endchar_pair.second;
|
||||||
|
rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
} else if (is_word_char(*pos)) { // rule reference
|
||||||
|
const char * name_end = parse_name(pos);
|
||||||
|
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
|
||||||
|
pos = parse_space(name_end, is_nested);
|
||||||
|
last_sym_start = rule.size();
|
||||||
|
rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
|
||||||
|
} else if (*pos == '(') { // grouping
|
||||||
|
// parse nested alternates into synthesized rule
|
||||||
|
pos = parse_space(pos + 1, true);
|
||||||
|
uint32_t sub_rule_id = generate_symbol_id(rule_name);
|
||||||
|
pos = parse_alternates(pos, rule_name, sub_rule_id, true);
|
||||||
|
last_sym_start = rule.size();
|
||||||
|
// output reference to synthesized rule
|
||||||
|
rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
|
||||||
|
if (*pos != ')') {
|
||||||
|
throw std::runtime_error(std::string("expecting ')' at ") + pos);
|
||||||
|
}
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
} else if (*pos == '.') { // any char
|
||||||
|
last_sym_start = rule.size();
|
||||||
|
rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
} else if (*pos == '*') {
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
handle_repetitions(0, -1);
|
||||||
|
} else if (*pos == '+') {
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
handle_repetitions(1, -1);
|
||||||
|
} else if (*pos == '?') {
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
handle_repetitions(0, 1);
|
||||||
|
} else if (*pos == '{') {
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
|
||||||
|
if (!is_digit_char(*pos)) {
|
||||||
|
throw std::runtime_error(std::string("expecting an int at ") + pos);
|
||||||
|
}
|
||||||
|
const char * int_end = parse_int(pos);
|
||||||
|
int min_times = std::stoul(std::string(pos, int_end - pos));
|
||||||
|
pos = parse_space(int_end, is_nested);
|
||||||
|
|
||||||
|
int max_times = -1;
|
||||||
|
|
||||||
|
if (*pos == '}') {
|
||||||
|
max_times = min_times;
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
} else if (*pos == ',') {
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
|
||||||
|
if (is_digit_char(*pos)) {
|
||||||
|
const char * int_end = parse_int(pos);
|
||||||
|
max_times = std::stoul(std::string(pos, int_end - pos));
|
||||||
|
pos = parse_space(int_end, is_nested);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (*pos != '}') {
|
||||||
|
throw std::runtime_error(std::string("expecting '}' at ") + pos);
|
||||||
|
}
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error(std::string("expecting ',' at ") + pos);
|
||||||
|
}
|
||||||
|
handle_repetitions(min_times, max_times);
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
return pos;
|
||||||
|
}
|
||||||
|
|
||||||
const char * llama_grammar_parser::parse_rule(const char * src) {
|
const char * llama_grammar_parser::parse_rule(const char * src) {
|
||||||
const char * name_end = parse_name(src);
|
const char * name_end = parse_name(src);
|
||||||
const char * pos = parse_space(name_end, false);
|
const char * pos = parse_space(name_end, false);
|
||||||
size_t name_len = name_end - src;
|
size_t name_len = name_end - src;
|
||||||
uint32_t rule_id = get_symbol_id(src, name_len);
|
uint32_t rule_id = get_symbol_id(src, name_len);
|
||||||
const std::string name(src, name_len);
|
const std::string name(src, name_len);
|
||||||
|
|
||||||
if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
|
if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
|
||||||
throw std::runtime_error(std::string("expecting ::= at ") + pos);
|
throw std::runtime_error(std::string("expecting ::= at ") + pos);
|
||||||
}
|
|
||||||
pos = parse_space(pos + 3, true);
|
|
||||||
|
|
||||||
pos = parse_alternates(pos, name, rule_id, false);
|
|
||||||
|
|
||||||
if (*pos == '\r') {
|
|
||||||
pos += pos[1] == '\n' ? 2 : 1;
|
|
||||||
} else if (*pos == '\n') {
|
|
||||||
pos++;
|
|
||||||
} else if (*pos) {
|
|
||||||
throw std::runtime_error(std::string("expecting newline or end at ") + pos);
|
|
||||||
}
|
|
||||||
return parse_space(pos, true);
|
|
||||||
}
|
}
|
||||||
|
pos = parse_space(pos + 3, true);
|
||||||
|
|
||||||
|
pos = parse_alternates(pos, name, rule_id, false);
|
||||||
|
|
||||||
|
if (*pos == '\r') {
|
||||||
|
pos += pos[1] == '\n' ? 2 : 1;
|
||||||
|
} else if (*pos == '\n') {
|
||||||
|
pos++;
|
||||||
|
} else if (*pos) {
|
||||||
|
throw std::runtime_error(std::string("expecting newline or end at ") + pos);
|
||||||
|
}
|
||||||
|
return parse_space(pos, true);
|
||||||
|
}
|
||||||
|
|
||||||
bool llama_grammar_parser::parse(const char * src) {
|
bool llama_grammar_parser::parse(const char * src) {
|
||||||
try {
|
try {
|
||||||
@@ -969,7 +969,7 @@ struct llama_grammar * llama_grammar_init_impl(
         /* .awaiting_trigger = */ false,
         /* .trigger_buffer   = */ "",
         /* .trigger_tokens   = */ {},
-        /* .trigger_words    = */ {},
+        /* .trigger_patterns = */ {},
     };
 }

@@ -978,19 +978,15 @@ struct llama_grammar * llama_grammar_init_impl(
         const char * grammar_str,
         const char * grammar_root,
         bool lazy,
-        const char ** trigger_words,
-        size_t num_trigger_words,
+        const char ** trigger_patterns,
+        size_t num_trigger_patterns,
         const llama_token * trigger_tokens,
         size_t num_trigger_tokens) {
     llama_grammar_parser parser;

     // if there is a grammar, parse it
-    if (!parser.parse(grammar_str)) {
-        return nullptr;
-    }
-
-    // will be empty (default) if there are parse errors
-    if (parser.rules.empty()) {
+    // rules will be empty (default) if there are parse errors
+    if (!parser.parse(grammar_str) || parser.rules.empty()) {
         fprintf(stderr, "%s: failed to parse grammar\n", __func__);
         return nullptr;
     }
@@ -1054,14 +1050,16 @@ struct llama_grammar * llama_grammar_init_impl(
     } while (true);

     std::vector<llama_token> vec_trigger_tokens;
-    std::vector<std::string> vec_trigger_words;
+    std::vector<llama_grammar_trigger_pattern> vec_trigger_patterns;
     for (size_t i = 0; i < num_trigger_tokens; i++) {
         GGML_ASSERT(trigger_tokens != nullptr);
         vec_trigger_tokens.push_back(trigger_tokens[i]);
     }
-    for (size_t i = 0; i < num_trigger_words; i++) {
-        GGML_ASSERT(trigger_words != nullptr);
-        vec_trigger_words.push_back(trigger_words[i]);
+    for (size_t i = 0; i < num_trigger_patterns; i++) {
+        GGML_ASSERT(trigger_patterns != nullptr);
+        auto & trigger = vec_trigger_patterns.emplace_back();
+        trigger.pattern = trigger_patterns[i];
+        trigger.regex = std::regex(trigger.pattern);
     }

     // Important: vec_rules has to be moved here, not copied, because stacks contains
@@ -1076,7 +1074,7 @@ struct llama_grammar * llama_grammar_init_impl(
         /* .awaiting_trigger = */ lazy,
         /* .trigger_buffer   = */ "",
         std::move(vec_trigger_tokens),
-        std::move(vec_trigger_words),
+        std::move(vec_trigger_patterns),
     };
 }

@@ -1089,7 +1087,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
 }

 struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
-    llama_grammar * result = new llama_grammar {
+    auto * result = new llama_grammar {
         grammar.vocab,
         grammar.rules,
         grammar.stacks,
@@ -1098,7 +1096,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
         grammar.awaiting_trigger,
         grammar.trigger_buffer,
         grammar.trigger_tokens,
-        grammar.trigger_words,
+        grammar.trigger_patterns,
     };

     // redirect elements in stacks to point to new rules
@@ -1173,20 +1171,22 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
             LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
             return;
         } else {
-            // TODO: consider a smarter incremental substring search algorithm (store last position to search from).
             grammar.trigger_buffer += piece;
-            for (const auto & word : grammar.trigger_words) {
-                auto pos = grammar.trigger_buffer.find(word);
-                if (pos != std::string::npos) {
+
+            std::smatch match;
+            for (const auto & trigger_pattern : grammar.trigger_patterns) {
+                if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
                     grammar.awaiting_trigger = false;
-                    auto constrained_str = grammar.trigger_buffer.substr(pos);
+                    // get from the first match to the end of the string
+                    auto constrained_str = grammar.trigger_buffer.substr(match.position(1));
+                    // std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
                     grammar.trigger_buffer.clear();
                     llama_grammar_accept_str(grammar, constrained_str);
-                    LLAMA_LOG_DEBUG("Grammar triggered on word `%s`", word.c_str());
+                    LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
                     return;
                 }
             }
-            LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`) (buffer: `%s`)\n", token, piece.c_str(), grammar.trigger_buffer.c_str());
+            LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str());
             return;
         }
     }
@@ -3,6 +3,7 @@
 #include "llama.h"

 #include <map>
+#include <regex>
 #include <string>
 #include <vector>

@@ -105,6 +106,11 @@ struct llama_grammar_parser {
     void print(FILE * file);
 };

+struct llama_grammar_trigger_pattern {
+    std::string pattern;
+    std::regex  regex;
+};
+
 struct llama_grammar {
     // note: allow null vocab for testing (not great)
     const llama_vocab * vocab;
@@ -116,13 +122,16 @@ struct llama_grammar {
     llama_partial_utf8 partial_utf8;

     // lazy grammars wait for trigger words or tokens before constraining the sampling.
-    // we still ahve trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
+    // we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
     // (useful e.g. for tool_choice=required)
     bool lazy = false;
     bool awaiting_trigger = false; // Initialized to true for lazy grammars only
     std::string trigger_buffer;    // Output buffered by lazy grammar. Will be cleared once trigger is found.
     std::vector<llama_token> trigger_tokens;   // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
-    std::vector<std::string> trigger_words;
+    std::vector<llama_grammar_trigger_pattern>
+                             trigger_patterns; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated
+                                               // string, and the grammar will be given the string from the first match group onwards.
+
 };

 //
@@ -141,8 +150,8 @@ struct llama_grammar * llama_grammar_init_impl(
     const char * grammar_str,
     const char * grammar_root,
     bool lazy,
-    const char ** trigger_words,
-    size_t num_trigger_words,
+    const char ** trigger_patterns,
+    size_t num_trigger_patterns,
     const llama_token * trigger_tokens,
     size_t num_trigger_tokens);
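To make the regex-based lazy triggers concrete, here is a hedged sketch of constructing a grammar that only starts constraining once the model emits a tool-call marker; the pattern text, vocab handle and grammar string are placeholders, not taken from the repository.

// Hypothetical usage of llama_grammar_init_impl with the new trigger_patterns interface.
// The full pattern must match the entire buffered output; capture group 1 marks where
// grammar enforcement begins.
#include "llama-grammar.h"

extern const struct llama_vocab * vocab; // assumed: vocab of a loaded model
extern const char * grammar_str;         // assumed: GBNF grammar text

struct llama_grammar * make_lazy_grammar() {
    static const char * trigger_patterns[] = {
        "[\\s\\S]*?(<tool_call>[\\s\\S]*)", // everything from "<tool_call>" onwards is constrained
    };
    return llama_grammar_init_impl(
        vocab,
        grammar_str,
        "root",
        /* lazy                 */ true,
        /* trigger_patterns     */ trigger_patterns,
        /* num_trigger_patterns */ 1,
        /* trigger_tokens       */ nullptr,
        /* num_trigger_tokens   */ 0);
}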
examples/talk-llama/llama-graph.cpp (new file, 1732 lines) — file diff suppressed because it is too large

examples/talk-llama/llama-graph.h (new file, 594 lines)
@@ -0,0 +1,594 @@
|
#pragma once
|
||||||
|
|
||||||
|
#include "llama-arch.h"
|
||||||
|
#include "llama-hparams.h"
|
||||||
|
#include "llama-adapter.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <set>
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
|
struct ggml_cgraph;
|
||||||
|
struct ggml_context;
|
||||||
|
struct ggml_tensor;
|
||||||
|
|
||||||
|
struct llama_ubatch;
|
||||||
|
struct llama_cparams;
|
||||||
|
|
||||||
|
class llama_memory_i;
|
||||||
|
class llama_kv_cache_unified;
|
||||||
|
|
||||||
|
// certain models (typically multi-modal) can produce different types of graphs
|
||||||
|
enum llm_graph_type {
|
||||||
|
LLM_GRAPH_TYPE_DEFAULT,
|
||||||
|
LLM_GRAPH_TYPE_ENCODER,
|
||||||
|
LLM_GRAPH_TYPE_DECODER,
|
||||||
|
};
|
||||||
|
|
||||||
|
enum llm_ffn_op_type {
|
||||||
|
LLM_FFN_SILU,
|
||||||
|
LLM_FFN_GELU,
|
||||||
|
LLM_FFN_RELU,
|
||||||
|
LLM_FFN_RELU_SQR,
|
||||||
|
LLM_FFN_SWIGLU,
|
||||||
|
};
|
||||||
|
|
||||||
|
enum llm_ffn_gate_type {
|
||||||
|
LLM_FFN_SEQ,
|
||||||
|
LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
|
||||||
|
};
|
||||||
|
|
||||||
|
enum llm_norm_type {
|
||||||
|
LLM_NORM,
|
||||||
|
LLM_NORM_RMS,
|
||||||
|
LLM_NORM_GROUP,
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO: tmp - need something better to pass the data from the encoder to the decoder
|
||||||
|
struct llama_cross {
|
||||||
|
// the output embeddings from the encoder as a ggml tensor
|
||||||
|
// TODO: this needs more work to be correct, for now copy the embeddings data to host memory
|
||||||
|
// ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
|
||||||
|
//ggml_tensor * t_embd = nullptr;
|
||||||
|
|
||||||
|
int64_t n_embd = 0;
|
||||||
|
int64_t n_enc = 0;
|
||||||
|
|
||||||
|
// embeddings data copied to host memory (tmp)
|
||||||
|
std::vector<float> v_embd;
|
||||||
|
|
||||||
|
// needed to construct the cross-attention mask in the decoder
|
||||||
|
std::vector<std::set<llama_seq_id>> seq_ids_enc;
|
||||||
|
};
|
||||||
|
|
||||||
|
//
|
||||||
|
// llm_graph_input
|
||||||
|
//
|
||||||
|
|
||||||
|
class llm_graph_input_i {
|
||||||
|
public:
|
||||||
|
virtual ~llm_graph_input_i() = default;
|
||||||
|
|
||||||
|
virtual void set_input(const llama_ubatch * ubatch) = 0;
|
||||||
|
};
|
||||||
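Each concrete input class in this new header follows the same pattern: it owns the ggml tensors a graph needs and fills them from the current ubatch inside set_input(). A minimal hedged sketch of such a class (the class name and tensor are illustrative, not part of this file):

// Illustrative only: a custom graph input mirroring the structure of the classes below.
class llm_graph_input_example : public llm_graph_input_i {
public:
    llm_graph_input_example() = default;
    virtual ~llm_graph_input_example() = default;

    // would copy per-token data from the ubatch into the tensor when the result's set_inputs() runs
    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * example = nullptr; // I32 [n_batch]
};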
|
|
||||||
|
using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
|
||||||
|
|
||||||
|
|
||||||
|
class llm_graph_input_embd : public llm_graph_input_i {
|
||||||
|
public:
|
||||||
|
llm_graph_input_embd() = default;
|
||||||
|
virtual ~llm_graph_input_embd() = default;
|
||||||
|
|
||||||
|
void set_input(const llama_ubatch * ubatch) override;
|
||||||
|
|
||||||
|
ggml_tensor * tokens = nullptr; // I32 [n_batch]
|
||||||
|
ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
|
||||||
|
};
|
||||||
|
|
||||||
|
class llm_graph_input_pos : public llm_graph_input_i {
|
||||||
|
public:
|
||||||
|
llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
|
||||||
|
virtual ~llm_graph_input_pos() = default;
|
||||||
|
|
||||||
|
void set_input(const llama_ubatch * ubatch) override;
|
||||||
|
|
||||||
|
ggml_tensor * pos = nullptr; // I32 [n_batch]
|
||||||
|
|
||||||
|
const int64_t n_pos_per_embd = 1;
|
||||||
|
};
|
||||||
|
|
||||||
|
// temperature tuning, used by llama4
|
||||||
|
class llm_graph_input_attn_temp : public llm_graph_input_i {
|
||||||
|
public:
|
||||||
|
llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
|
||||||
|
: n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
|
||||||
|
virtual ~llm_graph_input_attn_temp() = default;
|
||||||
|
|
||||||
|
void set_input(const llama_ubatch * ubatch) override;
|
||||||
|
|
||||||
|
ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
|
||||||
|
|
||||||
|
const uint32_t n_attn_temp_floor_scale;
|
||||||
|
const float f_attn_temp_scale;
|
||||||
|
};
|
||||||
|
|
||||||
|
class llm_graph_input_pos_bucket : public llm_graph_input_i {
|
||||||
|
public:
|
||||||
|
llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
|
||||||
|
virtual ~llm_graph_input_pos_bucket() = default;
|
||||||
|
|
||||||
|
void set_input(const llama_ubatch * ubatch) override;
|
||||||
|
|
||||||
|
ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]
|
||||||
|
|
||||||
|
const llama_hparams & hparams;
|
||||||
|
};
|
||||||
|
|
||||||
|
class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
|
||||||
|
public:
|
||||||
|
llm_graph_input_pos_bucket_kv(
|
||||||
|
const llama_hparams & hparams,
|
||||||
|
const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {}
|
||||||
|
virtual ~llm_graph_input_pos_bucket_kv() = default;
|
||||||
|
|
||||||
|
void set_input(const llama_ubatch * ubatch) override;
|
||||||
|
|
||||||
|
ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
|
||||||
|
|
||||||
|
const llama_hparams & hparams;
|
||||||
|
const llama_kv_cache_unified * kv_self;
|
||||||
|
};
|
||||||
|
|
||||||
|
class llm_graph_input_out_ids : public llm_graph_input_i {
|
||||||
|
public:
|
||||||
|
llm_graph_input_out_ids(
|
||||||
|
const llama_hparams & hparams,
|
||||||
|
const llama_cparams & cparams,
|
||||||
|
int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
|
||||||
|
virtual ~llm_graph_input_out_ids() = default;
|
||||||
|
|
||||||
|
void set_input(const llama_ubatch * ubatch) override;
|
||||||
|
|
||||||
|
ggml_tensor * out_ids; // I32 [n_outputs]
|
||||||
|
|
||||||
|
const llama_hparams & hparams;
|
||||||
|
const llama_cparams & cparams;
|
||||||
|
|
||||||
|
const int32_t n_outputs;
|
||||||
|
};
|
||||||
|
|
||||||
|
class llm_graph_input_mean : public llm_graph_input_i {
|
||||||
|
public:
|
||||||
|
llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {}
|
||||||
|
virtual ~llm_graph_input_mean() = default;
|
||||||
|
|
||||||
|
void set_input(const llama_ubatch * ubatch) override;
|
||||||
|
|
||||||
|
ggml_tensor * mean; // F32 [n_batch, n_batch]
|
||||||
|
|
||||||
|
const llama_cparams & cparams;
|
||||||
|
};
|
||||||
|
|
||||||
|
class llm_graph_input_cls : public llm_graph_input_i {
|
||||||
|
public:
|
||||||
|
llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
|
||||||
|
virtual ~llm_graph_input_cls() = default;
|
||||||
|
|
||||||
|
void set_input(const llama_ubatch * ubatch) override;
|
||||||
|
|
||||||
|
ggml_tensor * cls; // I32 [n_batch]
|
||||||
|
|
||||||
|
const llama_cparams & cparams;
|
||||||
|
};
|
||||||
|
|
||||||
|
class llm_graph_input_s_copy : public llm_graph_input_i {
|
||||||
|
public:
|
||||||
|
llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
|
||||||
|
virtual ~llm_graph_input_s_copy() = default;
|
||||||
|
|
||||||
|
void set_input(const llama_ubatch * ubatch) override;
|
||||||
|
|
||||||
|
ggml_tensor * s_copy; // I32 [kv_size]
|
||||||
|
|
||||||
|
const llama_kv_cache_unified * kv_self;
|
||||||
|
};
|
||||||
|
|
||||||
|
class llm_graph_input_s_mask : public llm_graph_input_i {
|
||||||
|
public:
|
||||||
|
llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
|
||||||
|
virtual ~llm_graph_input_s_mask() = default;
|
||||||
|
|
||||||
|
void set_input(const llama_ubatch * ubatch) override;
|
||||||
|
|
||||||
|
ggml_tensor * s_mask; // F32 [1, n_kv]
|
||||||
|
|
||||||
|
const llama_kv_cache_unified * kv_self;
|
||||||
|
};
|
||||||
|
|
||||||
|
class llm_graph_input_cross_embd : public llm_graph_input_i {
|
||||||
|
public:
|
||||||
|
llm_graph_input_cross_embd(
|
||||||
|
const llama_cross * cross) : cross(cross) {}
|
||||||
|
virtual ~llm_graph_input_cross_embd() = default;
|
||||||
|
|
||||||
|
void set_input(const llama_ubatch * ubatch) override;
|
||||||
|
|
||||||
|
ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]
|
||||||
|
|
||||||
|
const llama_cross * cross;
|
||||||
|
};
|
||||||
|
|
||||||
|
class llm_graph_input_attn_no_cache : public llm_graph_input_i {
|
||||||
|
public:
|
||||||
|
llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
|
||||||
|
hparams(hparams),
|
||||||
|
cparams(cparams) {
|
||||||
|
}
|
||||||
|
~llm_graph_input_attn_no_cache() = default;
|
||||||
|
|
||||||
|
void set_input(const llama_ubatch * ubatch) override;
|
||||||
|
|
||||||
|
ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
|
||||||
|
|
||||||
|
ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch]
|
||||||
|
ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch]
|
||||||
|
|
||||||
|
const llama_hparams & hparams;
|
||||||
|
const llama_cparams & cparams;
|
||||||
|
};
|
||||||
|
|
||||||
|
class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
|
||||||
|
public:
|
||||||
|
llm_graph_input_attn_kv_unified(
|
||||||
|
const llama_hparams & hparams,
|
||||||
|
const llama_cparams & cparams,
|
||||||
|
const llama_kv_cache_unified * kv_self) :
|
||||||
|
hparams(hparams),
|
||||||
|
cparams(cparams),
|
||||||
|
kv_self(kv_self) {
|
||||||
|
}
|
||||||
|
~llm_graph_input_attn_kv_unified() = default;
|
||||||
|
|
||||||
|
void set_input(const llama_ubatch * ubatch) override;
|
||||||
|
|
||||||
|
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
||||||
|
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
|
||||||
|
|
||||||
|
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
|
||||||
|
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
|
||||||
|
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
|
||||||
|
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
|
||||||
|
|
||||||
|
const llama_hparams & hparams;
|
||||||
|
const llama_cparams & cparams;
|
||||||
|
|
||||||
|
const llama_kv_cache_unified * kv_self;
|
||||||
|
};
|
||||||
|
|
||||||
|
class llm_graph_input_attn_cross : public llm_graph_input_i {
|
||||||
|
public:
|
||||||
|
llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
|
||||||
|
~llm_graph_input_attn_cross() = default;
|
||||||
|
|
||||||
|
void set_input(const llama_ubatch * ubatch) override;
|
||||||
|
|
||||||
|
ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
|
||||||
|
|
||||||
|
ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch]
|
||||||
|
ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
|
||||||
|
|
||||||
|
const llama_cross * cross = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
//
|
||||||
|
// llm_graph_result
|
||||||
|
//
|
||||||
|
|
||||||
|
// these objects deliver the result from the graph build process back to the llama_context
|
||||||
|
// note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
|
||||||
|
// specific data, by calling the set_inputs() method
|
||||||
|
// along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
|
||||||
|
// these are used by the llama_context to extact the relevant data, based on the compute parameters
|
||||||
|
|
||||||
|
class llm_graph_result_i {
|
||||||
|
public:
|
||||||
|
virtual ~llm_graph_result_i() = default;
|
||||||
|
|
||||||
|
virtual ggml_tensor * get_logits() = 0;
|
||||||
|
virtual ggml_tensor * get_embd() = 0;
|
||||||
|
virtual ggml_tensor * get_embd_pooled() = 0;
|
||||||
|
|
||||||
|
virtual void set_inputs(const llama_ubatch * ubatch) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
|
||||||
|
|
||||||
|
|
||||||
|
class llm_graph_result : public llm_graph_result_i {
|
||||||
|
public:
|
||||||
|
virtual ~llm_graph_result() = default;
|
||||||
|
|
||||||
|
ggml_tensor * get_logits() override { return t_logits; }
|
||||||
|
ggml_tensor * get_embd() override { return t_embd; }
|
||||||
|
ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
|
||||||
|
|
||||||
|
void set_inputs(const llama_ubatch * ubatch) override {
|
||||||
|
for (auto & input : inputs) {
|
||||||
|
input->set_input(ubatch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
llm_graph_input_i * add_input(llm_graph_input_ptr input) {
|
||||||
|
inputs.emplace_back(std::move(input));
|
||||||
|
return inputs.back().get();
|
||||||
|
}
|
||||||
|
|
||||||
|
// important graph nodes
|
||||||
|
ggml_tensor * t_logits = nullptr;
|
||||||
|
ggml_tensor * t_embd = nullptr;
|
||||||
|
ggml_tensor * t_embd_pooled = nullptr;
|
||||||
|
|
||||||
|
std::vector<llm_graph_input_ptr> inputs;
|
||||||
|
};
|
||||||
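In line with the comment above, a hedged sketch of how a caller such as llama_context is expected to drive a result object once a graph has been built; the helper function is illustrative, only set_inputs()/get_logits() and the ggml scheduler call come from existing APIs.

// Illustrative driver, not part of this header.
static void run_one_ubatch(llm_graph_result_i & res, const llama_ubatch & ubatch,
                           ggml_backend_sched * sched, ggml_cgraph * gf) {
    res.set_inputs(&ubatch);                      // populate every registered llm_graph_input_i
    ggml_backend_sched_graph_compute(sched, gf);  // execute the graph on the scheduler
    ggml_tensor * logits = res.get_logits();      // t_logits recorded during graph build
    (void) logits;                                // extraction into the output buffers happens elsewhere
}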
|
|
||||||
|
//
|
||||||
|
// llm_graph_context
|
||||||
|
//
|
||||||
|
|
||||||
|
// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
|
||||||
|
using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
|
||||||
|
|
||||||
|
struct llm_graph_params {
|
||||||
|
ggml_context * ctx;
|
||||||
|
|
||||||
|
const llm_arch arch;
|
||||||
|
|
||||||
|
const llama_hparams & hparams;
|
||||||
|
const llama_cparams & cparams;
|
||||||
|
const llama_ubatch & ubatch;
|
||||||
|
|
||||||
|
ggml_backend_sched * sched;
|
||||||
|
ggml_backend * backend_cpu;
|
||||||
|
|
||||||
|
const llama_adapter_cvec * cvec;
|
||||||
|
const llama_adapter_loras * loras;
|
||||||
|
const llama_memory_i * memory;
|
||||||
|
const llama_cross * cross;
|
||||||
|
|
||||||
|
int32_t n_outputs;
|
||||||
|
|
||||||
|
const llm_graph_cb & cb;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct llm_graph_context {
|
||||||
|
const llm_arch arch;
|
||||||
|
|
||||||
|
const llama_hparams & hparams;
|
||||||
|
const llama_cparams & cparams;
|
||||||
|
const llama_ubatch & ubatch;
|
||||||
|
|
||||||
|
const int64_t n_embd;
|
||||||
|
const int64_t n_layer;
|
||||||
|
const int64_t n_rot;
|
||||||
|
const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
|
||||||
|
const int64_t n_ctx_per_seq;
|
||||||
|
const int64_t n_head;
|
||||||
|
const int64_t n_head_kv;
|
||||||
|
const int64_t n_embd_head_k;
|
||||||
|
const int64_t n_embd_k_gqa;
|
||||||
|
const int64_t n_embd_head_v;
|
||||||
|
const int64_t n_embd_v_gqa;
|
||||||
|
const int64_t n_expert;
|
||||||
|
const int64_t n_expert_used;
|
||||||
|
|
||||||
|
const float freq_base;
|
||||||
|
const float freq_scale;
|
||||||
|
const float ext_factor;
|
||||||
|
const float attn_factor;
|
||||||
|
const float beta_fast;
|
||||||
|
const float beta_slow;
|
||||||
|
const float norm_eps;
|
||||||
|
const float norm_rms_eps;
|
||||||
|
|
||||||
|
const int32_t n_tokens;
|
||||||
|
const int32_t n_outputs;
|
||||||
|
const int32_t n_ctx_orig; // yarn
|
||||||
|
|
||||||
|
const enum llama_pooling_type pooling_type;
|
||||||
|
const enum llama_rope_type rope_type;
|
||||||
|
|
||||||
|
ggml_context * ctx0 = nullptr;
|
||||||
|
|
||||||
|
ggml_backend_sched * sched;
|
||||||
|
|
||||||
|
ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
|
||||||
|
|
||||||
|
const llama_adapter_cvec * cvec;
|
||||||
|
const llama_adapter_loras * loras;
|
||||||
|
const llama_memory_i * memory;
|
||||||
|
const llama_cross * cross;
|
||||||
|
|
||||||
|
const llm_graph_cb & cb_func;
|
||||||
|
|
||||||
|
std::unique_ptr<llm_graph_result> res;
|
||||||
|
|
||||||
|
llm_graph_context(const llm_graph_params & params);
|
||||||
|
|
||||||
|
int64_t n_pos_per_embd() const;
|
||||||
|
|
||||||
|
void cb(ggml_tensor * cur, const char * name, int il) const;
|
||||||
|
|
||||||
|
//
|
||||||
|
// common
|
||||||
|
//
|
||||||
|
|
||||||
|
ggml_tensor * build_cvec(
|
||||||
|
ggml_tensor * cur,
|
||||||
|
int il) const;
|
||||||
|
|
||||||
|
// do mat_mul, while optionally apply lora
|
||||||
|
ggml_tensor * build_lora_mm(
|
||||||
|
ggml_tensor * w,
|
||||||
|
ggml_tensor * cur) const;
|
||||||
|
|
||||||
|
// do mat_mul_id, while optionally apply lora
|
||||||
|
ggml_tensor * build_lora_mm_id(
|
||||||
|
ggml_tensor * w, // ggml_tensor * as
|
||||||
|
ggml_tensor * cur, // ggml_tensor * b
|
||||||
|
ggml_tensor * ids) const;
|
||||||
|
|
||||||
|
ggml_tensor * build_norm(
|
||||||
|
ggml_tensor * cur,
|
||||||
|
ggml_tensor * mw,
|
||||||
|
ggml_tensor * mb,
|
||||||
|
llm_norm_type type,
|
||||||
|
int il) const;
|
||||||
|
|
||||||
|
ggml_tensor * build_ffn(
|
||||||
|
ggml_tensor * cur,
|
||||||
|
ggml_tensor * up,
|
||||||
|
ggml_tensor * up_b,
|
||||||
|
ggml_tensor * up_s,
|
||||||
|
ggml_tensor * gate,
|
||||||
|
ggml_tensor * gate_b,
|
||||||
|
ggml_tensor * gate_s,
|
||||||
|
ggml_tensor * down,
|
||||||
|
ggml_tensor * down_b,
|
||||||
|
ggml_tensor * down_s,
|
||||||
|
ggml_tensor * act_scales,
|
||||||
|
llm_ffn_op_type type_op,
|
||||||
|
llm_ffn_gate_type type_gate,
|
||||||
|
int il) const;
|
||||||
|
|
||||||
|
ggml_tensor * build_moe_ffn(
|
||||||
|
ggml_tensor * cur,
|
||||||
|
ggml_tensor * gate_inp,
|
||||||
|
ggml_tensor * up_exps,
|
||||||
|
ggml_tensor * gate_exps,
|
||||||
|
ggml_tensor * down_exps,
|
||||||
|
ggml_tensor * exp_probs_b,
|
||||||
|
int64_t n_expert,
|
||||||
|
int64_t n_expert_used,
|
||||||
|
llm_ffn_op_type type_op,
|
||||||
|
bool norm_w,
|
||||||
|
bool scale_w,
|
||||||
|
float w_scale,
|
||||||
|
llama_expert_gating_func_type gating_op,
|
||||||
|
int il) const;
|
||||||
|
|
||||||
|
//
|
||||||
|
// inputs
|
||||||
|
//
|
||||||
|
|
||||||
|
ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
|
||||||
|
ggml_tensor * build_inp_pos() const;
|
||||||
|
ggml_tensor * build_inp_attn_scale() const;
|
||||||
|
ggml_tensor * build_inp_out_ids() const;
|
||||||
|
ggml_tensor * build_inp_mean() const;
|
||||||
|
ggml_tensor * build_inp_cls() const;
|
||||||
|
ggml_tensor * build_inp_s_copy() const;
|
||||||
|
ggml_tensor * build_inp_s_mask() const;
|
||||||
|
|
||||||
|
ggml_tensor * build_inp_cross_embd() const;
|
||||||
|
ggml_tensor * build_inp_pos_bucket_enc() const;
|
||||||
|
ggml_tensor * build_inp_pos_bucket_dec() const;
|
||||||
|
ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
|
||||||
|
|
||||||
|
//
|
||||||
|
// attention
|
||||||
|
//
|
||||||
|
|
||||||
|
ggml_tensor * build_attn_mha(
|
||||||
|
ggml_cgraph * gf,
|
||||||
|
ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
|
||||||
|
ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
|
||||||
|
ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
|
||||||
|
ggml_tensor * kq_b,
|
||||||
|
ggml_tensor * kq_mask,
|
||||||
|
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||||
|
bool v_trans,
|
||||||
|
float kq_scale) const;
|
||||||
|
|
||||||
|
llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
|
||||||
|
|
||||||
|
ggml_tensor * build_attn(
|
||||||
|
llm_graph_input_attn_no_cache * inp,
|
||||||
|
ggml_cgraph * gf,
|
||||||
|
ggml_tensor * wo,
|
||||||
|
ggml_tensor * wo_b,
|
||||||
|
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
||||||
|
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
|
||||||
|
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
|
||||||
|
ggml_tensor * kq_b,
|
||||||
|
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||||
|
float kq_scale,
|
||||||
|
int il) const;
|
||||||
|
|
||||||
|
llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
|
||||||
|
|
||||||
|
ggml_tensor * build_attn(
|
||||||
|
llm_graph_input_attn_kv_unified * inp,
|
||||||
|
ggml_cgraph * gf,
|
||||||
|
ggml_tensor * wo,
|
||||||
|
ggml_tensor * wo_b,
|
||||||
|
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
||||||
|
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
|
||||||
|
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
|
||||||
|
ggml_tensor * kq_b,
|
||||||
|
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||||
|
float kq_scale,
|
||||||
|
int il) const;
|
||||||
|
|
||||||
|
llm_graph_input_attn_cross * build_attn_inp_cross() const;
|
||||||
|
|
||||||
|
ggml_tensor * build_attn(
|
||||||
|
llm_graph_input_attn_cross * inp,
|
||||||
|
ggml_cgraph * gf,
|
||||||
|
ggml_tensor * wo,
|
||||||
|
ggml_tensor * wo_b,
|
||||||
|
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
||||||
|
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
|
||||||
|
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
|
||||||
|
ggml_tensor * kq_b,
|
||||||
|
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||||
|
float kq_scale,
|
||||||
|
int il) const;
|
||||||
|
|
||||||
|
//
|
||||||
|
// recurrent
|
||||||
|
//
|
||||||
|
|
||||||
|
ggml_tensor * build_copy_mask_state(
|
||||||
|
ggml_cgraph * gf,
|
||||||
|
ggml_tensor * s,
|
||||||
|
ggml_tensor * state_copy,
|
||||||
|
ggml_tensor * state_mask,
|
||||||
|
int32_t n_state,
|
||||||
|
int32_t n_seqs) const;
|
||||||
|
|
||||||
|
ggml_tensor * build_rwkv_token_shift_load(
|
||||||
|
ggml_cgraph * gf,
|
||||||
|
ggml_tensor * state_copy,
|
||||||
|
ggml_tensor * state_mask,
|
||||||
|
const llama_ubatch & ubatch,
|
||||||
|
int il) const;
|
||||||
|
|
||||||
|
ggml_tensor * build_rwkv_token_shift_store(
|
||||||
|
ggml_tensor * token_shift,
|
||||||
|
const llama_ubatch & ubatch,
|
||||||
|
int il) const;
|
||||||
|
|
||||||
|
//
|
||||||
|
// pooling
|
||||||
|
//
|
||||||
|
|
||||||
|
void build_pooling(
|
||||||
|
ggml_cgraph * gf,
|
||||||
|
ggml_tensor * cls,
|
||||||
|
ggml_tensor * cls_b,
|
||||||
|
ggml_tensor * cls_out,
|
||||||
|
ggml_tensor * cls_out_b) const;
|
||||||
|
};
|
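For orientation, here is a minimal sketch of a callback matching the llm_graph_cb signature declared above. It only renames tensors with a layer suffix; a real callback could also steer allocation or offloading. The helper name is illustrative and assumes the header above is included.

#include "llama-graph.h" // assumed include for llm_graph_cb, llama_ubatch, ggml_tensor
#include <cstdio>

static llm_graph_cb make_debug_cb() {
    return [](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) {
        (void) ubatch; // a fuller callback could inspect the micro-batch as well
        if (il >= 0) {
            char buf[GGML_MAX_NAME];
            snprintf(buf, sizeof(buf), "%s-%d", name, il); // e.g. "ffn_out-12"
            ggml_set_name(cur, buf);
        } else {
            ggml_set_name(cur, name);
        }
    };
}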
@ -69,3 +69,11 @@ uint32_t llama_hparams::n_embd_v_s() const {
     // corresponds to Mamba's ssm_states size
     return ssm_d_state * ssm_d_inner;
 }
+
+bool llama_hparams::is_swa(uint32_t il) const {
+    if (il < n_layer) {
+        return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
+    }
+
+    GGML_ABORT("fatal error");
+}
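To make the rule above concrete, a small standalone sketch of which layers end up with sliding-window attention; the parameter values are chosen purely for illustration.

// With n_swa_pattern == 4, layers with il % 4 in {0, 1, 2} use SWA and every
// 4th layer (il % 4 == 3) keeps full attention. Standalone sketch, not library code.
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_layer = 8, n_swa = 512, n_swa_pattern = 4; // example values
    for (uint32_t il = 0; il < n_layer; ++il) {
        const bool swa = n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
        printf("layer %u: %s\n", il, swa ? "SWA" : "full attention");
    }
    return 0;
}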
@ -36,12 +36,17 @@ struct llama_hparams {
     uint32_t n_layer;
     uint32_t n_rot;
     uint32_t n_swa = 0; // sliding window attention (SWA)
+    uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
     uint32_t n_expert = 0;
     uint32_t n_expert_used = 0;
     uint32_t n_rel_attn_bkts = 0;
+
+    // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
+    uint32_t n_embd_head_k_mla = 0;
+    uint32_t n_embd_head_v_mla = 0;
+
     // for WavTokenizer
     struct llama_hparams_posnet   posnet;
     struct llama_hparams_convnext convnext;
@ -61,6 +66,7 @@ struct llama_hparams {
     float expert_weights_scale = 0.0;
     bool  expert_weights_norm  = false;
     uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
+    uint32_t moe_every_n_layers = 0;

     float f_norm_eps;
     float f_norm_rms_eps;
@ -75,10 +81,16 @@ struct llama_hparams {
     uint32_t time_decay_extra_dim = 0;
     uint32_t wkv_head_size = 0;
     uint32_t token_shift_count = 2;
+    uint32_t n_lora_decay = 0;
+    uint32_t n_lora_iclr = 0;
+    uint32_t n_lora_value_res_mix = 0;
+    uint32_t n_lora_gate = 0;

     float rope_attn_factor = 1.0f;
     float rope_freq_base_train;
+    float rope_freq_base_train_swa;
     float rope_freq_scale_train;
+    float rope_freq_scale_train_swa;
     uint32_t n_ctx_orig_yarn;
     float rope_yarn_log_mul;

@ -105,6 +117,14 @@ struct llama_hparams {
     bool use_alibi = false;
     bool attn_soft_cap = false;
+
+    uint32_t n_moe_layer_step = 0;
+    bool     use_kq_norm = true;
+    uint32_t n_attn_chunk = 0;
+    // values below seems to be fixed on llama4
+    uint32_t n_no_rope_layer_step = 4;
+    uint32_t n_attn_temp_floor_scale = 8192;
+    float    f_attn_temp_scale = 0.1;

     // needed by encoder-decoder models (e.g. T5, FLAN-T5)
     // ref: https://github.com/ggerganov/llama.cpp/pull/8141
     llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
@ -133,6 +153,8 @@ struct llama_hparams {

     // dimension of the recurrent state embeddings
     uint32_t n_embd_v_s() const;
+
+    bool is_swa(uint32_t il) const;
 };

 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
@ -6,13 +6,13 @@
 #include <vector>

 #ifdef __GNUC__
-#ifdef __MINGW32__
-#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
-#else
-#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
-#endif
-#else
-#define LLAMA_ATTRIBUTE_FORMAT(...)
+#    if defined(__MINGW32__) && !defined(__clang__)
+#        define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
+#    else
+#        define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
+#    endif
+#else
+#    define LLAMA_ATTRIBUTE_FORMAT(...)
 #endif

 //
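As a reminder of what the macro is for, a short hedged sketch of applying it to a printf-style logging helper; the helper itself is illustrative and assumes the header defining LLAMA_ATTRIBUTE_FORMAT is included. With GCC/MinGW the attribute lets the compiler type-check the format string against the arguments.

#include <cstdarg>
#include <cstdio>

LLAMA_ATTRIBUTE_FORMAT(1, 2)
static void example_logf(const char * fmt, ...) {
    va_list args;
    va_start(args, fmt);
    vfprintf(stderr, fmt, args); // forward to stderr; real code would route through the library logger
    va_end(args);
}

// example_logf("decoded %d tokens in %.2f ms\n", n_tokens, t_ms);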
examples/talk-llama/llama-io.cpp (new file, 15 lines)
@ -0,0 +1,15 @@
#include "llama-io.h"

void llama_io_write_i::write_string(const std::string & str) {
    uint32_t str_size = str.size();

    write(&str_size, sizeof(str_size));
    write(str.data(), str_size);
}

void llama_io_read_i::read_string(std::string & str) {
    uint32_t str_size;
    read_to(&str_size, sizeof(str_size));

    str.assign((const char *) read(str_size), str_size);
}
examples/talk-llama/llama-io.h (new file, 35 lines)
@ -0,0 +1,35 @@
#pragma once

#include <cstddef>
#include <cstdint>
#include <string>

struct ggml_tensor;

class llama_io_write_i {
public:
    llama_io_write_i() = default;
    virtual ~llama_io_write_i() = default;

    virtual void write(const void * src, size_t size) = 0;
    virtual void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) = 0;

    // bytes written so far
    virtual size_t n_bytes() = 0;

    void write_string(const std::string & str);
};

class llama_io_read_i {
public:
    llama_io_read_i() = default;
    virtual ~llama_io_read_i() = default;

    virtual const uint8_t * read(size_t size) = 0;
    virtual void read_to(void * dst, size_t size) = 0;

    // bytes read so far
    virtual size_t n_bytes() = 0;

    void read_string(std::string & str);
};
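A hedged sketch of how this interface can be satisfied by an in-memory writer; the class name and its std::vector backing are illustrative, not part of the library.

#include "llama-io.h" // assumed include for llama_io_write_i
#include "ggml.h"     // assumed include for the full ggml_tensor definition
#include <cstdint>
#include <vector>

class example_io_write_buffer : public llama_io_write_i {
public:
    void write(const void * src, size_t size) override {
        const uint8_t * p = static_cast<const uint8_t *>(src);
        data.insert(data.end(), p, p + size);
    }

    void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
        // assumes the tensor data lives in host memory; backend-resident tensors
        // would need ggml_backend_tensor_get or similar instead
        write(static_cast<const uint8_t *>(tensor->data) + offset, size);
    }

    size_t n_bytes() override { return data.size(); }

    std::vector<uint8_t> data;
};

// usage: example_io_write_buffer buf; buf.write_string("hello"); // 4-byte length + payload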
File diff suppressed because it is too large
@ -1,15 +1,51 @@
 #pragma once

 #include "llama.h"
+#include "llama-io.h"
+#include "llama-memory.h"

 #include "ggml-cpp.h"

+#include <functional>
 #include <set>
 #include <vector>

+struct llama_cparams;
+struct llama_hparams;
+struct llama_ubatch;
+
+struct llama_kv_cache : public llama_memory_i {
+    using llama_memory_i::llama_memory_i;
+
+    virtual void restore() = 0; // call if batch processing fails - restores the cache state
+    virtual void commit()  = 0; // call after successful batch processing - clears any pending state
+
+    virtual int32_t get_n_tokens()   const = 0;
+    virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
+
+    virtual bool get_can_shift() const = 0;
+
+    bool get_can_edit() const override { return get_can_shift(); }
+};
+
+struct llama_kv_cache_guard {
+    llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
+
+    ~llama_kv_cache_guard() {
+        kv->restore();
+    }
+
+    void commit() {
+        kv->commit();
+    }
+
+private:
+    llama_kv_cache * kv;
+};
+
 struct llama_kv_cell {
     llama_pos pos   = -1;
     llama_pos delta =  0;

     int32_t src  = -1; // used by recurrent state models to copy states
     int32_t tail = -1;

@ -29,15 +65,112 @@ struct llama_kv_cell {
 };

 // ring-buffer of cached KV data
-struct llama_kv_cache {
+// TODO: pimpl
+// TODO: add notion of max sequences
+class llama_kv_cache_unified : public llama_kv_cache {
+public:
+    // can be used to query data from the model if needed
+    struct callbacks {
+        std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
+    };
+
+    llama_kv_cache_unified(
+            const llama_hparams & hparams,
+            callbacks cbs);
+
+    virtual ~llama_kv_cache_unified() = default;
+
+    // TODO: become constructor
+    bool init(
+            const llama_model & model, // TODO: do not reference the model
+            const llama_cparams & cparams,
+                      ggml_type   type_k,
+                      ggml_type   type_v,
+                       uint32_t   kv_size,
+                           bool   offload);
+
+    int32_t get_n_tokens()   const override;
+    int32_t get_used_cells() const override;
+
+    size_t total_size() const;
+
+    // TODO: better data structures to reduce the cost of this operation
+    llama_pos pos_max() const;
+
+    void clear() override;
+    void defrag() override;
+
+    virtual void restore() override;
+    virtual void commit() override;
+
+    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id) override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override;
+    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
+    bool get_can_shift() const override;
+
+    // find an empty slot of size "n_tokens" in the cache
+    // updates the cache head
+    // Note: On success, it's important that cache.head points
+    // to the first cell of the slot.
+    bool find_slot(const llama_ubatch & batch);
+
+    // TODO: maybe not needed
+    uint32_t get_padding(const llama_cparams & cparams) const;
+
+    // find how many cells are currently in use
+    uint32_t cell_max() const;
+
+    size_t size_k_bytes() const;
+    size_t size_v_bytes() const;
+
+    // defrag
+
+    struct {
+        std::vector<uint32_t> ids;
+    } defrag_info;
+
+    // return true if cells have been moved
+    bool defrag_prepare(int32_t n_max_nodes);
+
+    // commit/restore cache
+
+    struct slot_range {
+        uint32_t c0 = 0; // note: these are cell indices, not sequence positions
+        uint32_t c1 = 0;
+    };
+
+    // pending cell updates that are not yet committed
+    struct {
+        std::vector<slot_range> ranges;
+    } pending;
+
+    // state write/load
+
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1);
+
+    // members
+
+    const llama_hparams & hparams;
+
+    callbacks cbs;
+
     bool has_shift = false;
     bool do_defrag = false;
+
+    // TODO: remove this and implement llama_kv_cache_recurrent instead
     bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
+
     bool v_trans   = true;  // the value tensor is transposed
     bool can_shift = false;

     // Note: The value of head isn't only used to optimize searching
-    // for a free KV slot. llama_decode_internal also uses it, so it
+    // for a free KV slot. llama_decode_impl also uses it, so it
     // cannot be freely changed after a slot has been allocated.
     uint32_t head = 0;
     uint32_t size = 0;
@ -46,173 +179,35 @@ struct llama_kv_cache {
     // computed before each graph build
     uint32_t n = 0;

+    std::vector<llama_kv_cell> cells;
+
+    std::vector<ggml_tensor *> k_l; // per layer
+    std::vector<ggml_tensor *> v_l;
+
+private:
     ggml_type type_k = GGML_TYPE_F16;
     ggml_type type_v = GGML_TYPE_F16;

-    std::vector<llama_kv_cell> cells;
-
-    std::vector<struct ggml_tensor *> k_l; // per layer
-    std::vector<struct ggml_tensor *> v_l;
-
-    std::vector<ggml_context_ptr> ctxs;
+    std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;

-    size_t total_size() const {
-        size_t size = 0;
-        for (const auto & buf : bufs) {
-            size += ggml_backend_buffer_get_size(buf.get());
-        }
-        return size;
-    }
-
-    // TODO: better data structures to reduce the cost of this operation
-    llama_pos max_pos() const {
-        llama_pos max_pos = -1;
-        for (const auto & cell : cells) {
-            max_pos = std::max(max_pos, cell.pos);
-        }
-        return max_pos;
-    }
+    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
+    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
+
+    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
+    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
 };

-// a structure holds information about the slot found in llama_kv_cache_find_slot
-struct llama_kv_cache_slot_info {
-    std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
-    bool found = false; // the slot was found
-
-    explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
-    llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}
-
-    operator bool() const { return found; }
-};
-
-// TODO: maybe not needed
-uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams);
-
-bool llama_kv_cache_init(
-        struct llama_kv_cache & cache,
-            const llama_model & model,
-          const llama_cparams & cparams,
-                    ggml_type   type_k,
-                    ggml_type   type_v,
-                     uint32_t   kv_size,
-                         bool   offload);
-
-// find an empty slot of size "n_tokens" in the cache
-// updates the cache head
-// returns a structure holding information about the slot found
-// Note: On success, it's important that cache.head points
-// to the first cell of the slot.
-struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
-        struct llama_kv_cache & cache,
-    const struct llama_ubatch & batch);
-
-// find how many cells are currently in use
-uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache);
-
-void llama_kv_cache_clear(struct llama_kv_cache & cache);
-
-bool llama_kv_cache_seq_rm(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id,
-                    llama_pos   p0,
-                    llama_pos   p1);
-
-void llama_kv_cache_seq_cp(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id_src,
-                 llama_seq_id   seq_id_dst,
-                    llama_pos   p0,
-                    llama_pos   p1);
-
-void llama_kv_cache_seq_keep(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id);
-
-void llama_kv_cache_seq_add(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id,
-                    llama_pos   p0,
-                    llama_pos   p1,
-                    llama_pos   delta);
-
-void llama_kv_cache_seq_div(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id,
-                    llama_pos   p0,
-                    llama_pos   p1,
-                          int   d);
-
-llama_pos llama_kv_cache_seq_pos_max(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id);
-
-void llama_kv_cache_defrag(struct llama_kv_cache & cache);
-
-int32_t llama_get_kv_cache_token_count(const struct llama_kv_cache & kv);
-
-int32_t llama_get_kv_cache_used_cells(const struct llama_kv_cache & kv);
-
-bool llama_kv_cache_can_shift(const struct llama_kv_cache & kv);
+// TODO: temporary reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified
+//class llama_kv_cache_recurrent : public llama_kv_cache_unified {
+//public:
+//    using llama_kv_cache_unified::llama_kv_cache_unified;
+//};

 //
 // kv cache view
 //

-struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max);
+llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max);

-void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv);
-
-//
-// kv cache restore
-//
-
-// saves the kv_cache state for future recovery.
-// used to rollback llama_kv_cache_find_slot changes.
-struct llama_kv_slot_restorer {
-    struct llama_kv_cache_state {
-        uint32_t head = 0;
-        uint32_t n    = 0;
-    } old_state;
-
-    // for non-recurrent models only
-    // list of slots to restore
-    std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries;
-
-    bool do_restore = false;
-
-    explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
-        old_state.head = cache.head;
-        old_state.n    = cache.n;
-    }
-
-    // saves a slot information for future restoration
-    void save(const struct llama_kv_cache_slot_info & slot) {
-        if (slot) {
-            do_restore = true;
-            if (slot.boundaries.first != slot.boundaries.second) {
-                slot_boundaries.push_back(slot.boundaries);
-            }
-        }
-    }
-
-    // must be explicitly called to restore the kv_cache state
-    // and rollback changes from all llama_kv_cache_find_slot calls
-    void restore(struct llama_kv_cache & cache) {
-        if (do_restore) {
-            cache.head = old_state.head;
-            cache.n    = old_state.n;
-
-            if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
-                llama_kv_cache_seq_rm(cache, -1, -1, -1);
-            } else {
-                for (auto & slot : slot_boundaries) {
-                    llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second);
-                }
-            }
-        }
-    }
-};
+void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv);
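The guard introduced above replaces the old llama_kv_slot_restorer with RAII commit/restore semantics. A hedged sketch of the intended usage pattern; the process_batch callable stands in for whatever decode step actually mutates the cache.

#include <functional>

static bool example_decode_with_guard(llama_kv_cache & kv, const std::function<bool()> & process_batch) {
    llama_kv_cache_guard guard(&kv);

    if (!process_batch()) {
        return false; // the guard destructor calls kv.restore(), rolling back pending slot changes
    }

    guard.commit(); // clears the pending ranges, so the destructor's restore() has nothing left to undo
    return true;
}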
examples/talk-llama/llama-memory.cpp (new file, 1 line)
@ -0,0 +1 @@
#include "llama-memory.h"
examples/talk-llama/llama-memory.h (new file, 21 lines)
@ -0,0 +1,21 @@
#pragma once

#include "llama.h"

// general concept of LLM memory
// the KV cache is a type of LLM memory, but there can be other types
class llama_memory_i {
public:
    virtual void clear() = 0;
    virtual void defrag() = 0;

    virtual bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) = 0;
    virtual void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
    virtual void seq_keep(llama_seq_id seq_id) = 0;
    virtual void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) = 0;
    virtual void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) = 0;

    virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;

    virtual bool get_can_edit() const = 0;
};
@ -8,6 +8,7 @@
 #include <climits>
 #include <stdexcept>
 #include <cerrno>
+#include <algorithm>

 #ifdef __has_include
 #if __has_include(<unistd.h>)
@ -34,6 +35,10 @@
 #include <io.h>
 #endif

+#if defined(__APPLE__)
+#include <TargetConditionals.h>
+#endif
+
 // TODO: consider moving to llama-impl.h if needed in more places
 #if defined(_WIN32)
 static std::string llama_format_win_err(DWORD err) {
@ -471,7 +476,11 @@ struct llama_mlock::impl {

         char* errmsg = std::strerror(errno);
         bool suggest = (errno == ENOMEM);
+#if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV) || defined(_AIX)
+        // visionOS/tvOS dont't support RLIMIT_MEMLOCK
+        // Skip resource limit checks on visionOS/tvOS
+        suggest = false;
+#else
         struct rlimit lock_limit;
         if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) {
             suggest = false;
@ -479,6 +488,7 @@ struct llama_mlock::impl {
         if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) {
             suggest = false;
         }
+#endif

         LLAMA_LOG_WARN("warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s",
             size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : "");
@ -1,5 +1,6 @@
 #pragma once

+#include <cstdint>
 #include <memory>
 #include <vector>

@ -445,7 +445,8 @@ llama_model_loader::llama_model_loader(
         std::vector<std::string> & splits,
         bool use_mmap,
         bool check_tensors,
-        const struct llama_model_kv_override * param_overrides_p) {
+        const llama_model_kv_override * param_overrides_p,
+        const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) {
     int trace = 0;
     if (getenv("LLAMA_TRACE")) {
         trace = atoi(getenv("LLAMA_TRACE"));
@ -457,6 +458,8 @@ llama_model_loader::llama_model_loader(
         }
     }

+    tensor_buft_overrides = param_tensor_buft_overrides_p;
+
     // Load the main GGUF
     struct ggml_context * ctx = NULL;
     struct gguf_init_params params = {
@ -600,7 +603,9 @@ llama_model_loader::llama_model_loader(

             if (trace > 0) {
                 const uint16_t sid = w.idx;
-                LLAMA_LOG_INFO("%s: - tensor split %2d: %32s %-8s [ %s ]\n", __func__, sid, ggml_get_name(tensor), ggml_type_name(type), llama_format_tensor_shape(tensor).c_str());
+                LLAMA_LOG_INFO("%s: - tensor split %2d: %32s %-8s [ %s ] %8.2f MiB\n", __func__,
+                        sid, ggml_get_name(tensor), ggml_type_name(type), llama_format_tensor_shape(tensor).c_str(),
+                        ggml_nbytes(tensor)/1024.0f/1024.0f);
             }
         }

@ -640,9 +645,9 @@ llama_model_loader::llama_model_loader(
     ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED);

     {
-        const int kid = gguf_find_key(meta.get(), "general.file_type"); // TODO: use LLM_KV
-        if (kid >= 0) {
-            ftype = (llama_ftype) gguf_get_val_u32(meta.get(), kid);
+        uint32_t ftype_val = 0;
+        if (get_key(LLM_KV_GENERAL_FILE_TYPE, ftype_val, false)) {
+            ftype = (llama_ftype) ftype_val;
         }
     }

@ -77,8 +77,9 @@ struct llama_model_loader {

     llama_mmaps mappings;

-    std::map<std::string, struct llama_tensor_weight, weight_name_comparer> weights_map;
-    std::unordered_map<std::string, struct llama_model_kv_override> kv_overrides;
+    std::map<std::string, llama_tensor_weight, weight_name_comparer> weights_map;
+    std::unordered_map<std::string, llama_model_kv_override> kv_overrides;
+    const llama_model_tensor_buft_override * tensor_buft_overrides;

     gguf_context_ptr meta;
     std::vector<ggml_context_ptr> contexts;
@ -95,7 +96,8 @@ struct llama_model_loader {
         std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme
         bool use_mmap,
         bool check_tensors,
-        const struct llama_model_kv_override * param_overrides_p);
+        const llama_model_kv_override * param_overrides_p,
+        const llama_model_tensor_buft_override * param_tensor_buft_overrides_p);

     template<typename T>
     typename std::enable_if<std::is_integral<T>::value, bool>::type
File diff suppressed because it is too large
@ -2,7 +2,9 @@

 #include "llama.h"
 #include "llama-arch.h"
+#include "llama-graph.h"
 #include "llama-hparams.h"
+#include "llama-memory.h"
 #include "llama-vocab.h"

 #include <memory>
@ -10,6 +12,8 @@
 #include <unordered_map>
 #include <vector>

+struct llama_cparams;
+struct llama_ubatch;
 struct llama_model_loader;

 // available models
@ -25,22 +29,28 @@ enum llm_type {
     LLM_TYPE_109M,
     LLM_TYPE_137M,
     LLM_TYPE_160M,
+    LLM_TYPE_190M,
     LLM_TYPE_220M,
     LLM_TYPE_250M,
     LLM_TYPE_270M,
     LLM_TYPE_335M,
     LLM_TYPE_410M,
     LLM_TYPE_450M,
+    LLM_TYPE_475M,
     LLM_TYPE_770M,
     LLM_TYPE_780M,
     LLM_TYPE_0_5B,
+    LLM_TYPE_0_6B,
     LLM_TYPE_1B,
     LLM_TYPE_1_3B,
     LLM_TYPE_1_4B,
     LLM_TYPE_1_5B,
     LLM_TYPE_1_6B,
+    LLM_TYPE_1_7B,
+    LLM_TYPE_1_8B,
     LLM_TYPE_2B,
     LLM_TYPE_2_8B,
+    LLM_TYPE_2_9B,
     LLM_TYPE_3B,
     LLM_TYPE_4B,
     LLM_TYPE_6B,
@ -55,6 +65,7 @@ enum llm_type {
     LLM_TYPE_15B,
     LLM_TYPE_16B,
     LLM_TYPE_20B,
+    LLM_TYPE_27B,
     LLM_TYPE_30B,
     LLM_TYPE_32B,
     LLM_TYPE_34B,
@ -63,6 +74,7 @@ enum llm_type {
     LLM_TYPE_65B,
     LLM_TYPE_70B,
     LLM_TYPE_236B,
+    LLM_TYPE_290B,
     LLM_TYPE_314B,
     LLM_TYPE_671B,
     LLM_TYPE_SMALL,
@ -77,7 +89,10 @@ enum llm_type {
     LLM_TYPE_16x3_8B,
     LLM_TYPE_10B_128x3_66B,
     LLM_TYPE_57B_A14B,
-    LLM_TYPE_27B,
+    LLM_TYPE_17B_16E,  // llama4 Scout
+    LLM_TYPE_17B_128E, // llama4 Maverick
+    LLM_TYPE_30B_A3B,
+    LLM_TYPE_235B_A22B,
 };

 struct llama_layer_posnet {
@ -161,6 +176,8 @@ struct llama_layer {
     struct ggml_tensor * wq_b = nullptr;
     struct ggml_tensor * wkv_a_mqa = nullptr;
     struct ggml_tensor * wkv_b = nullptr;
+    struct ggml_tensor * wk_b = nullptr;
+    struct ggml_tensor * wv_b = nullptr;
     struct ggml_tensor * wq_cross = nullptr;
     struct ggml_tensor * wk_cross = nullptr;
     struct ggml_tensor * wv_cross = nullptr;
@ -256,6 +273,20 @@ struct llama_layer {
     struct ggml_tensor * time_mix_receptance_b = nullptr;
     struct ggml_tensor * time_mix_gate = nullptr;

+    // rwkv7
+    struct ggml_tensor * time_mix_w0 = nullptr;
+    struct ggml_tensor * time_mix_a0 = nullptr;
+    struct ggml_tensor * time_mix_a1 = nullptr;
+    struct ggml_tensor * time_mix_a2 = nullptr;
+    struct ggml_tensor * time_mix_v0 = nullptr;
+    struct ggml_tensor * time_mix_v1 = nullptr;
+    struct ggml_tensor * time_mix_v2 = nullptr;
+    struct ggml_tensor * time_mix_g1 = nullptr;
+    struct ggml_tensor * time_mix_g2 = nullptr;
+    struct ggml_tensor * time_mix_k_k = nullptr;
+    struct ggml_tensor * time_mix_k_a = nullptr;
+    struct ggml_tensor * time_mix_r_k = nullptr;
+
     struct ggml_tensor * time_mix_ln = nullptr;
     struct ggml_tensor * time_mix_ln_b = nullptr;
     struct ggml_tensor * time_mix_output = nullptr;
@ -347,7 +378,7 @@ struct llama_model {
     std::string desc() const;

     size_t size() const;
-    size_t max_nodes() const;
+    size_t n_tensors() const;
     size_t n_devices() const;

     // total number of parameters in the model
@ -360,11 +391,26 @@ struct llama_model {

     ggml_backend_buffer_type_t select_buft(int il) const;

+    bool has_tensor_overrides() const;
+
     const struct ggml_tensor * get_tensor(const char * name) const;

+    // TODO: move this to new llm_arch_model_i interface
+    llama_memory_i * create_memory() const; // TODO: params
+
+    // TODO: move this to new llm_arch_model_i interface
+    llm_graph_result_ptr build_graph(
+            const llm_graph_params & params,
+                       ggml_cgraph * gf,
+                    llm_graph_type   type) const;
+
 private:
     struct impl;
     std::unique_ptr<impl> pimpl;
 };

 const char * llm_type_name(llm_type type);
+
+// For internal test use
+// TODO: remove
+const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model);
@ -10,6 +10,7 @@
 #include <cinttypes>
 #include <fstream>
 #include <mutex>
+#include <regex>
 #include <thread>
 #include <unordered_map>

@ -47,8 +48,14 @@ struct quantize_state_impl {
     {}
 };

+// changes to this struct must be replicated in quantize.cpp
+struct tensor_quantization {
+    std::string name;
+    ggml_type quant = GGML_TYPE_COUNT;
+};
+
 static void llama_tensor_dequantize_impl(
-    struct ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
+    ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers,
     const size_t nelements, const int nthread
 ) {
     if (output.size() < nelements) {
@ -527,7 +534,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
     }

     std::vector<std::string> splits = {};
-    llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides);
+    llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides, nullptr);
     ml.init_mappings(false); // no prefetching

     llama_model model(llama_model_default_params());
@ -536,7 +543,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
     model.load_hparams(ml);
     model.load_stats (ml);

-    struct quantize_state_impl qs(model, params);
+    quantize_state_impl qs(model, params);

     if (params->only_copy) {
         ftype = ml.ftype;
@ -661,7 +668,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
     // populate the original tensors so we get an initial meta data
     for (const auto * it : tensors) {
         uint16_t i_split = params->keep_split ? it->idx : 0;
-        struct ggml_tensor * tensor = it->tensor;
+        ggml_tensor * tensor = it->tensor;
         if (!ctx_outs[i_split]) {
             ctx_outs[i_split].reset(gguf_init_empty());
         }
@ -710,7 +717,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
     new_ofstream(0);
     for (const auto * it : tensors) {
         const auto & weight = *it;
-        struct ggml_tensor * tensor = weight.tensor;
+        ggml_tensor * tensor = weight.tensor;
         if (weight.idx != cur_split && params->keep_split) {
             close_ofstream();
             new_ofstream(weight.idx);
@ -756,10 +763,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         // NOTE: can't use LLM_TN here because the layer number is not known
         quantize &= name.find("ssm_conv1d.weight") == std::string::npos;

-        // do not quantize RWKV's time_mix_first tensors
+        // do not quantize RWKV's small yet 2D weights
         quantize &= name.find("time_mix_first.weight") == std::string::npos;
+        quantize &= name.find("time_mix_w0.weight") == std::string::npos;
         quantize &= name.find("time_mix_w1.weight") == std::string::npos;
         quantize &= name.find("time_mix_w2.weight") == std::string::npos;
+        quantize &= name.find("time_mix_v0.weight") == std::string::npos;
+        quantize &= name.find("time_mix_v1.weight") == std::string::npos;
+        quantize &= name.find("time_mix_v2.weight") == std::string::npos;
+        quantize &= name.find("time_mix_a0.weight") == std::string::npos;
+        quantize &= name.find("time_mix_a1.weight") == std::string::npos;
+        quantize &= name.find("time_mix_a2.weight") == std::string::npos;
+        quantize &= name.find("time_mix_g1.weight") == std::string::npos;
+        quantize &= name.find("time_mix_g2.weight") == std::string::npos;
         quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
         quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
         quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
@ -767,7 +783,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         // do not quantize relative position bias (T5)
         quantize &= name.find("attn_rel_b.weight") == std::string::npos;

-        enum ggml_type new_type;
+        ggml_type new_type;
         void * new_data;
         size_t new_size;

@ -777,6 +793,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
             // get more optimal quantization type based on the tensor shape, layer, etc.
             if (!params->pure && ggml_is_quantized(default_type)) {
                 new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
+                // unless the user specifies a type
+                if (params->tensor_types) {
+                    const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
+                    for (const auto & [tname, qtype] : tensor_types) {
+                        if (std::regex pattern(tname); std::regex_search(tensor->name, pattern)) {
+                            if (qtype != new_type) {
+                                LLAMA_LOG_DEBUG("(overriding %s -> %s), ", ggml_type_name(new_type), ggml_type_name(qtype));
+                            }
+                            new_type = qtype;
+                            break;
+                        }
+                    }
+                }
             }
             if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
                 new_type = params->token_embedding_type;
@ -901,8 +930,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
 // interface implementation
 //

-struct llama_model_quantize_params llama_model_quantize_default_params() {
-    struct llama_model_quantize_params result = {
+llama_model_quantize_params llama_model_quantize_default_params() {
+    llama_model_quantize_params result = {
         /*.nthread                 =*/ 0,
         /*.ftype                   =*/ LLAMA_FTYPE_MOSTLY_Q5_1,
         /*.output_tensor_type      =*/ GGML_TYPE_COUNT,
@ -914,6 +943,7 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
         /*.keep_split              =*/ false,
         /*.imatrix                 =*/ nullptr,
         /*.kv_overrides            =*/ nullptr,
+        /*.tensor_type             =*/ nullptr,
     };

     return result;
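The new tensor_types path above lets a caller force specific tensors to a given quantization type via regex matching on tensor names. A hedged sketch of assembling such an override list before passing it through the quantize params; the regex strings and chosen types are examples only, not shipped defaults.

#include "ggml.h" // assumed include for ggml_type constants
#include <vector>

static std::vector<tensor_quantization> example_tensor_type_overrides() {
    std::vector<tensor_quantization> overrides;
    overrides.push_back({ "ffn_down",        GGML_TYPE_Q6_K }); // any tensor whose name matches "ffn_down"
    overrides.push_back({ "attn_v\\.weight", GGML_TYPE_Q8_0 }); // escaped '.' so it matches literally
    return overrides; // the address of this vector would be passed via params->tensor_types
}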
@ -232,7 +232,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
|
|||||||
// }
|
// }
|
||||||
|
|
||||||
if (k <= 0) {
|
if (k <= 0) {
|
||||||
k = cur_p->size;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
k = std::min(k, (int) cur_p->size);
|
k = std::min(k, (int) cur_p->size);
|
||||||
@ -298,6 +298,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
|
|||||||
}
|
}
|
||||||
cur_p->sorted = true;
|
cur_p->sorted = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
cur_p->size = k;
|
cur_p->size = k;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -316,6 +317,13 @@ static uint32_t get_rng_seed(uint32_t seed) {
|
|||||||
|
|
||||||
// llama_sampler API
|
// llama_sampler API
|
||||||
|
|
||||||
|
struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) {
|
||||||
|
return new llama_sampler {
|
||||||
|
/* .iface = */ iface,
|
||||||
|
/* .ctx = */ ctx,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
const char * llama_sampler_name(const struct llama_sampler * smpl) {
|
const char * llama_sampler_name(const struct llama_sampler * smpl) {
|
||||||
if (!smpl->iface) {
|
if (!smpl->iface) {
|
||||||
return "(null)";
|
return "(null)";
|
||||||
@ -347,10 +355,10 @@ struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (smpl->ctx == nullptr) {
|
if (smpl->ctx == nullptr) {
|
||||||
return new llama_sampler {
|
return llama_sampler_init(
|
||||||
/* .iface = */ smpl->iface,
|
/* .iface = */ smpl->iface,
|
||||||
/* .ctx = */ nullptr,
|
/* .ctx = */ nullptr
|
||||||
};
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
GGML_ABORT("the sampler does not support cloning");
|
GGML_ABORT("the sampler does not support cloning");
|
||||||
@ -472,15 +480,15 @@ static struct llama_sampler_i llama_sampler_chain_i = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
|
struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
|
||||||
return new llama_sampler {
|
return llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_chain_i,
|
/* .iface = */ &llama_sampler_chain_i,
|
||||||
/* .ctx = */ new llama_sampler_chain {
|
/* .ctx = */ new llama_sampler_chain {
|
||||||
/* .params = */ params,
|
/* .params = */ params,
|
||||||
/* .samplers = */ {},
|
/* .samplers = */ {},
|
||||||
/* .t_sample_us = */ 0,
|
/* .t_sample_us = */ 0,
|
||||||
/* .n_sample = */ 0,
|
/* .n_sample = */ 0,
|
||||||
},
|
}
|
||||||
};
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
|
void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
|
||||||
@ -546,10 +554,10 @@ static struct llama_sampler_i llama_sampler_greedy_i = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_greedy() {
|
struct llama_sampler * llama_sampler_init_greedy() {
|
||||||
return new llama_sampler {
|
return llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_greedy_i,
|
/* .iface = */ &llama_sampler_greedy_i,
|
||||||
/* .ctx = */ nullptr,
|
/* .ctx = */ nullptr
|
||||||
};
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// dist
|
// dist
|
||||||
@ -608,14 +616,14 @@ static struct llama_sampler_i llama_sampler_dist_i = {
|
|||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
|
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
|
||||||
auto seed_cur = get_rng_seed(seed);
|
auto seed_cur = get_rng_seed(seed);
|
||||||
return new llama_sampler {
|
return llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_dist_i,
|
/* .iface = */ &llama_sampler_dist_i,
|
||||||
/* .ctx = */ new llama_sampler_dist {
|
/* .ctx = */ new llama_sampler_dist {
|
||||||
/* .seed = */ seed,
|
/* .seed = */ seed,
|
||||||
/* .seed_cur = */ seed_cur,
|
/* .seed_cur = */ seed_cur,
|
||||||
/* .rng = */ std::mt19937(seed_cur),
|
/* .rng = */ std::mt19937(seed_cur),
|
||||||
},
|
}
|
||||||
};
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// softmax
|
// softmax
|
||||||
@ -638,10 +646,10 @@ static struct llama_sampler_i llama_sampler_softmax_i = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_softmax() {
|
struct llama_sampler * llama_sampler_init_softmax() {
|
||||||
return new llama_sampler {
|
return llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_softmax_i,
|
/* .iface = */ &llama_sampler_softmax_i,
|
||||||
/* .ctx = */ nullptr,
|
/* .ctx = */ nullptr
|
||||||
};
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// top-k
|
// top-k
|
||||||
@ -678,12 +686,12 @@ static struct llama_sampler_i llama_sampler_top_k_i = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
|
struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
|
||||||
return new llama_sampler {
|
return llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_top_k_i,
|
/* .iface = */ &llama_sampler_top_k_i,
|
||||||
/* .ctx = */ new llama_sampler_top_k {
|
/* .ctx = */ new llama_sampler_top_k {
|
||||||
/* .k = */ k,
|
/* .k = */ k,
|
||||||
},
|
}
|
||||||
};
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// top-p
|
// top-p
|
||||||
@ -744,13 +752,13 @@ static struct llama_sampler_i llama_sampler_top_p_i = {
 };

 struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_top_p_i,
         /* .ctx = */ new llama_sampler_top_p {
             /* .p = */ p,
             /* .min_keep = */ min_keep,
-        },
-    };
+        }
+    );
 }

 // min-p
@ -840,13 +848,13 @@ static struct llama_sampler_i llama_sampler_min_p_i = {
 };

 struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_min_p_i,
         /* .ctx = */ new llama_sampler_min_p {
             /* .p = */ p,
             /* .min_keep = */ min_keep,
-        },
-    };
+        }
+    );
 }

 // typical
@ -939,13 +947,13 @@ static struct llama_sampler_i llama_sampler_typical_i = {
 };

 struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_typical_i,
         /* .ctx = */ new llama_sampler_typical {
             /* .p = */ p,
             /* .min_keep = */ min_keep,
-        },
-    };
+        }
+    );
 }

 // temp
@ -983,12 +991,12 @@ static struct llama_sampler_i llama_sampler_temp_i = {
 };

 struct llama_sampler * llama_sampler_init_temp(float temp) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_temp_i,
         /* .ctx = */ new llama_sampler_temp {
             /*.temp = */ temp,
-        },
-    };
+        }
+    );
 }

 // temp-ext
@ -1093,14 +1101,14 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
 };

 struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_temp_ext_i,
         /* .ctx = */ new llama_sampler_temp_ext {
             /* .temp = */ temp,
             /* .delta = */ delta,
             /* .exponent = */ exponent,
-        },
-    };
+        }
+    );
 }

 // xtc
@ -1185,7 +1193,7 @@ static struct llama_sampler_i llama_sampler_xtc_i = {

 struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
     auto seed_cur = get_rng_seed(seed);
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_xtc_i,
         /* .ctx = */ new llama_sampler_xtc {
             /* .probability = */ p,
@ -1194,8 +1202,8 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
             /* .seed = */ seed,
             /* .seed_cur = */ seed_cur,
             /* .rng = */ std::mt19937(seed_cur),
-        },
-    };
+        }
+    );
 }

 // mirostat
@ -1292,7 +1300,7 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {

 struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
     auto seed_cur = get_rng_seed(seed);
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_mirostat_i,
         /* .ctx = */ new llama_sampler_mirostat {
             /* .n_vocab = */ n_vocab,
@ -1303,8 +1311,8 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see
             /* .m = */ m,
             /* .mu = */ 2.0f*tau,
             /* .rng = */ std::mt19937(seed_cur),
-        },
-    };
+        }
+    );
 }

 // mirostat v2
@ -1391,7 +1399,7 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {

 struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
     auto seed_cur = get_rng_seed(seed);
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_mirostat_v2_i,
         /* .ctx = */ new llama_sampler_mirostat_v2 {
             /* .seed = */ seed,
@ -1400,8 +1408,8 @@ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau,
             /* .eta = */ eta,
             /* .mu = */ 2.0f*tau,
             /* .rng = */ std::mt19937(seed_cur),
-        },
-    };
+        }
+    );
 }

 // grammar
@ -1442,7 +1450,9 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
         const char ** trigger_words,
         size_t num_trigger_words,
         const llama_token * trigger_tokens,
-        size_t num_trigger_tokens);
+        size_t num_trigger_tokens,
+        const char ** trigger_patterns,
+        size_t num_trigger_patterns);

 static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
     auto * ctx = (llama_sampler_grammar *) smpl->ctx;
@ -1450,12 +1460,14 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
         return;
     }

-    std::vector<const char *> trigger_words;
-    for (auto & word : ctx->grammar->trigger_words) {
-        trigger_words.push_back(word.c_str());
+    std::vector<const char *> trigger_patterns_c;
+    trigger_patterns_c.reserve(ctx->grammar->trigger_patterns.size());
+    for (auto & trigger_pattern : ctx->grammar->trigger_patterns) {
+        trigger_patterns_c.push_back(trigger_pattern.pattern.c_str());
     }

     auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
-            ctx->grammar->lazy, trigger_words.data(), trigger_words.size(),
+            ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(),
             ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());

     llama_grammar_free_impl(ctx->grammar);
@ -1465,7 +1477,8 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
 static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;

-    auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0);
+    auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0, nullptr, 0);
+    GGML_ASSERT(result);

     // copy the state
     {
@ -1509,16 +1522,38 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
         const char ** trigger_words,
         size_t num_trigger_words,
         const llama_token * trigger_tokens,
-        size_t num_trigger_tokens) {
+        size_t num_trigger_tokens,
+        const char ** trigger_patterns,
+        size_t num_trigger_patterns) {
     auto * ctx = new llama_sampler_grammar;

     if (grammar_str != nullptr && grammar_str[0] != '\0') {
+        // TODO: remove trigger_words support.
+        if (trigger_words != nullptr && num_trigger_words > 0) {
+            GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0);
+            std::string trigger_pattern("[\\s\\S]*?(");
+            for (size_t i = 0; i < num_trigger_words; ++i) {
+                static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
+                if (i > 0) {
+                    trigger_pattern += "|";
+                }
+                trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
+            }
+            trigger_pattern += ")[\\s\\S]*";
+            auto trigger_pattern_c = trigger_pattern.c_str();
+            trigger_patterns = &trigger_pattern_c;
+            num_trigger_patterns = 1;
+        }
         *ctx = {
             /* .vocab = */ vocab,
             /* .grammar_str = */ grammar_str,
             /* .grammar_root = */ grammar_root,
-            /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens),
+            /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
         };
+        if (!ctx->grammar) {
+            delete ctx;
+            return nullptr;
+        }
     } else {
         *ctx = {
             /* .vocab = */ vocab,
@ -1528,17 +1563,17 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
         };
     }

-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_grammar_i,
-        /* .ctx = */ ctx,
-    };
+        /* .ctx = */ ctx
+    );
 }

 struct llama_sampler * llama_sampler_init_grammar(
         const struct llama_vocab * vocab,
         const char * grammar_str,
         const char * grammar_root) {
-    return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0);
+    return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0, nullptr, 0);
 }

 struct llama_sampler * llama_sampler_init_grammar_lazy(
@ -1549,7 +1584,18 @@ struct llama_sampler * llama_sampler_init_grammar_lazy(
         size_t num_trigger_words,
         const llama_token * trigger_tokens,
         size_t num_trigger_tokens) {
-    return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens);
+    return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens, nullptr, 0);
+}
+
+struct llama_sampler * llama_sampler_init_grammar_lazy_patterns(
+        const struct llama_vocab * vocab,
+        const char * grammar_str,
+        const char * grammar_root,
+        const char ** trigger_patterns,
+        size_t num_trigger_patterns,
+        const llama_token * trigger_tokens,
+        size_t num_trigger_tokens) {
+    return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, nullptr, 0, trigger_tokens, num_trigger_tokens, trigger_patterns, num_trigger_patterns);
 }

 // penalties
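For reference, the loop added above collapses the deprecated trigger_words into a single regex of the form [\s\S]*?(word1|word2|...)[\s\S]*, escaping any regex metacharacters in the words. A stand-alone re-creation of that loop; the sample words are made up, and words containing metacharacters would additionally be backslash-escaped:

// --- illustrative example (not part of the diff) ---
#include <cstdio>
#include <regex>
#include <string>
#include <vector>

int main() {
    std::vector<std::string> trigger_words = { "<tool_call>", "```json" };

    static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");

    std::string trigger_pattern("[\\s\\S]*?(");
    for (size_t i = 0; i < trigger_words.size(); ++i) {
        if (i > 0) {
            trigger_pattern += "|";
        }
        trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
    }
    trigger_pattern += ")[\\s\\S]*";

    // prints: [\s\S]*?(<tool_call>|```json)[\s\S]*
    std::printf("%s\n", trigger_pattern.c_str());
    return 0;
}
// --- end example ---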
@ -1678,7 +1724,7 @@ struct llama_sampler * llama_sampler_init_penalties(
         float penalty_present) {
     penalty_last_n = std::max(penalty_last_n, 0);

-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_penalties_i,
         /* .ctx = */ new llama_sampler_penalties {
             /* .penalty_last_n = */ penalty_last_n,
@ -1687,8 +1733,75 @@ struct llama_sampler * llama_sampler_init_penalties(
             /* .penalty_present = */ penalty_present,
             /* .prev = */ ring_buffer<llama_token>(penalty_last_n),
             /* .token_count = */ {},
-        },
-    };
+        }
+    );
+}
+
+// top-n-sigma
+
+struct llama_sampler_top_n_sigma {
+    const float n;
+};
+
+static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
+    return "top-n-sigma";
+}
+
+static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
+
+    // find max logit and calculate mean
+    float max = cur_p->data[0].logit;
+    float logits_sum = 0;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (cur_p->data[i].logit > max) {
+            max = cur_p->data[i].logit;
+        }
+        logits_sum += cur_p->data[i].logit;
+    }
+    float mean = logits_sum/cur_p->size;
+
+    // calculate standard deviation
+    float acc = 0;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        acc += pow(cur_p->data[i].logit - mean, 2);
+    }
+    float std = sqrt(acc/cur_p->size);
+
+    //apply mask
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (cur_p->data[i].logit < max - (ctx->n * std)) {
+            cur_p->data[i].logit = -INFINITY;
+        }
+    }
+    llama_sampler_softmax_impl(cur_p);
+}
+
+static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx;
+    return llama_sampler_init_top_n_sigma(ctx->n);
+}
+
+static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_top_n_sigma *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
+    /* .name   = */ llama_sampler_top_n_sigma_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_top_n_sigma_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_top_n_sigma_clone,
+    /* .free   = */ llama_sampler_top_n_sigma_free,
+};
+
+struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
+    return llama_sampler_init(
+        /* .iface = */ &llama_sampler_top_n_sigma_i,
+        /* .ctx = */ new llama_sampler_top_n_sigma {
+            /* .n = */ n,
+        }
+    );
 }

 // DRY
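The mask in llama_sampler_top_n_sigma_apply keeps only candidates whose logit is within n standard deviations of the maximum (logit >= max - n*std) and then re-normalizes with softmax. A small stand-alone illustration of the same arithmetic on made-up logits:

// --- illustrative example (not part of the diff) ---
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const std::vector<float> logits = { 9.0f, 8.5f, 6.0f, 1.0f };
    const float n = 1.0f; // threshold in standard deviations

    float max = logits[0], sum = 0.0f;
    for (float l : logits) { max = std::max(max, l); sum += l; }
    const float mean = sum / logits.size();

    float acc = 0.0f;
    for (float l : logits) { acc += (l - mean) * (l - mean); }
    const float std_dev = std::sqrt(acc / logits.size());

    // with these values only the last candidate falls below max - n*std_dev
    for (float l : logits) {
        std::printf("%4.1f -> %s\n", l, l < max - n * std_dev ? "masked" : "kept");
    }
    return 0;
}
// --- end example ---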
@ -2041,7 +2154,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
         }
     }

-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_dry_i,
         /* .ctx = */ new llama_sampler_dry {
             /* .total_context_size = */ context_size,
@ -2053,8 +2166,8 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
             /* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
             /* .dry_max_token_repeat = */ {},
             /* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
-        },
-    };
+        }
+    );
 }

 // wrapper for test-sampling.cpp
@ -2155,14 +2268,14 @@ struct llama_sampler * llama_sampler_init_logit_bias(
         int32_t n_vocab,
         int32_t n_logit_bias,
         const llama_logit_bias * logit_bias) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_logit_bias_i,
         /* .ctx = */ new llama_sampler_logit_bias {
             /* .n_vocab = */ n_vocab,
             /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
             /* .to_search = */ {},
-        },
-    };
+        }
+    );
 }

 // infill
@ -2377,14 +2490,14 @@ static struct llama_sampler_i llama_sampler_infill_i = {
 };

 struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_infill_i,
         /* .ctx = */ new llama_sampler_infill {
             /* .vocab = */ vocab,
             /* .buf0 = */ std::vector<char>(512),
             /* .buf1 = */ std::vector<char>(512),
-        },
-    };
+        }
+    );
 }

 // utils
@ -16,6 +16,7 @@
 #include <queue>
 #include <set>
 #include <unordered_map>
+#include <cctype>

 //
 // helpers
@ -341,6 +342,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
             case LLAMA_VOCAB_PRE_TYPE_MPT:
             case LLAMA_VOCAB_PRE_TYPE_OLMO:
             case LLAMA_VOCAB_PRE_TYPE_JAIS:
+            case LLAMA_VOCAB_PRE_TYPE_TRILLION:
                 regex_exprs = {
                     "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
                 };
@ -392,6 +394,27 @@ struct llm_tokenizer_bpe : llm_tokenizer {
                     "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
                 };
                 break;
+            case LLAMA_VOCAB_PRE_TYPE_GPT4O:
+                regex_exprs = {
+                    // original regex from tokenizer.json
+                    // "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                    "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_SUPERBPE:
+                regex_exprs = {
+                    "\\p{N}+",
+                    "(?=(\\d{3})+(?!\\d))",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_BAILINGMOE:
+                regex_exprs = {
+                    // original regex from tokenizer.json
+                    // "'(?i:[sdmt]|ll|ve|re)|[^\\r\\n\\p{L}\\p{N}]?+\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]++[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+"
+                    // FIXME? Changed possessive quantifiers (?+ and ++) to greedy to avoid errors and imatrix hanging (tried atomic grouping but it's not supported?)
+                    "'(?:[sSdDmMtT]|[lL][lL]|[vV][eE]|[rR][eE])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
+                };
+                break;
             default:
                 // default regex for BPE tokenization pre-processing
                 regex_exprs = {
@ -1483,7 +1506,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "llama3"   ||
                 tokenizer_pre == "llama-v3" ||
                 tokenizer_pre == "llama-bpe"||
-                tokenizer_pre == "falcon3") {
+                tokenizer_pre == "falcon3"  ||
+                tokenizer_pre == "pixtral") {
             pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
             ignore_merges = true;
             add_bos = true;
@ -1549,6 +1573,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             pre_type = LLAMA_VOCAB_PRE_TYPE_PORO;
             clean_spaces = false;
         } else if (
+                tokenizer_pre == "glm4" ||
                 tokenizer_pre == "chatglm-bpe") {
             pre_type = LLAMA_VOCAB_PRE_TYPE_CHATGLM4;
             special_bos_id = LLAMA_TOKEN_NULL;
@ -1592,6 +1617,23 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
         } else if (
                 tokenizer_pre == "megrez") {
             pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
+        } else if (
+                tokenizer_pre == "gpt-4o" ||
+                tokenizer_pre == "llama4") {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O;
+            clean_spaces = false;
+        } else if (
+                tokenizer_pre == "superbpe") {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_SUPERBPE;
+            clean_spaces = false;
+        } else if (
+                tokenizer_pre == "trillion") {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_TRILLION;
+            clean_spaces = false;
+        } else if (
+                tokenizer_pre == "bailingmoe") {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE;
+            clean_spaces = false;
         } else {
             throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
         }
@ -1769,6 +1811,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                     || t.first == "<end_of_turn>"
                     || t.first == "<|endoftext|>"
                     || t.first == "<EOT>"
+                    || t.first == "_<EOT>"
                     || t.first == "<|end▁of▁sentence|>" // DeepSeek
                ) {
                 special_eot_id = t.second;
@ -1799,8 +1842,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             if (false
                     || t.first == "<|fim_prefix|>" // Qwen
                     || t.first == "<fim-prefix>"
+                    || t.first == "<fim_prefix>" // Granite
                     || t.first == "<|fim▁begin|>" // DeepSeek
                     || t.first == "<PRE>"
+                    || t.first == "▁<PRE>" // CodeLlama
                ) {
                 special_fim_pre_id = t.second;
                 if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@ -1816,8 +1861,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             if (false
                     || t.first == "<|fim_suffix|>" // Qwen
                     || t.first == "<fim-suffix>"
+                    || t.first == "<fim_suffix>" // Granite
                     || t.first == "<|fim▁hole|>" // DeepSeek
                     || t.first == "<SUF>"
+                    || t.first == "▁<SUF>" // CodeLlama
                ) {
                 special_fim_suf_id = t.second;
                 if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@ -1833,8 +1880,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             if (false
                     || t.first == "<|fim_middle|>" // Qwen
                     || t.first == "<fim-middle>"
+                    || t.first == "<fim_middle>" // Granite
                     || t.first == "<|fim▁end|>" // DeepSeek
                     || t.first == "<MID>"
+                    || t.first == "▁<MID>" // CodeLlama
                ) {
                 special_fim_mid_id = t.second;
                 if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@ -1850,6 +1899,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             if (false
                     || t.first == "<|fim_pad|>" // Qwen
                     || t.first == "<fim-pad>"
+                    || t.first == "<fim_pad>" // Granite
                     || t.first == "<PAD>"
                ) {
                 special_fim_pad_id = t.second;
@ -1868,6 +1918,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                     || t.first == "<|repo_name|>"
                     || t.first == "<fim-repo>"
                     || t.first == "<REPO>"
+                    || t.first == "<reponame>" // Granite
                ) {
                 special_fim_rep_id = t.second;
                 if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@ -1919,6 +1970,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                     || t.first == "<|endoftext|>"
                     || t.first == "<|eom_id|>"
                     || t.first == "<EOT>"
+                    || t.first == "_<EOT>"
                ) {
                 special_eog_ids.insert(t.second);
                 if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@ -2177,14 +2229,12 @@ void llama_vocab::impl::tokenizer_st_partition(std::forward_list<fragment_buffer
             // find the first occurrence of a given special token in this fragment
             // passing offset argument only limit the "search area" but match coordinates
             // are still relative to the source full raw_text
-            auto match = raw_text.find(text, raw_text_base_offset);
+            // string_view begins at pos 0 for the same reason
+            auto match = std::string_view(raw_text.data(), raw_text_base_offset + raw_text_base_length).find(text, raw_text_base_offset);

             // no occurrences found, stop processing this fragment for a given special token
             if (match == std::string::npos) break;

-            // check if match is within bounds of offset <-> length
-            if (match + text.length() > raw_text_base_offset + raw_text_base_length) break;
-
 #ifdef PRETOKENIZERDEBUG
             LLAMA_LOG_WARN("FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
 #endif
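The replacement search above bounds the lookup to the current fragment by building a string_view over raw_text that ends at raw_text_base_offset + raw_text_base_length, while starting find() at raw_text_base_offset; indices stay relative to the full string, which is why the separate bounds check could be dropped. A toy stand-alone demonstration with made-up values:

// --- illustrative example (not part of the diff) ---
#include <cstdio>
#include <string>
#include <string_view>

int main() {
    const std::string raw_text = "abc<EOT>def<EOT>";
    const std::string text     = "<EOT>";

    const size_t base_offset = 0;
    const size_t base_length = 8; // fragment covers "abc<EOT>" only

    // search only inside the fragment, but keep indices relative to raw_text
    const auto match = std::string_view(raw_text.data(), base_offset + base_length)
                           .find(text, base_offset);

    if (match != std::string_view::npos) {
        std::printf("match at offset %zu\n", match); // prints 3; the second "<EOT>" is out of range
    }
    return 0;
}
// --- end example ---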
File diff suppressed because it is too large
@ -60,6 +60,7 @@ extern "C" {
     struct llama_model;
     struct llama_context;
     struct llama_sampler;
+    struct llama_kv_cache;

     typedef int32_t llama_pos;
     typedef int32_t llama_token;
@ -105,6 +106,12 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
         LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
         LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
+        LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
+        LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
+        LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
+        LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
+        LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
+        LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
     };

     enum llama_rope_type {
@ -213,7 +220,7 @@ extern "C" {
         LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported
     };

-    // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979)
+    // TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979)
     typedef struct llama_token_data {
         llama_token id; // token id
         float logit;    // log-odds of the token
@ -275,10 +282,18 @@ extern "C" {
         };
     };

+    struct llama_model_tensor_buft_override {
+        const char * pattern;
+        ggml_backend_buffer_type_t buft;
+    };
+
     struct llama_model_params {
         // NULL-terminated list of devices to use for offloading (if NULL, all available devices are used)
         ggml_backend_dev_t * devices;

+        // NULL-terminated list of buffer types to use for tensors that match a pattern
+        const struct llama_model_tensor_buft_override * tensor_buft_overrides;
+
         int32_t n_gpu_layers; // number of layers to store in VRAM
         enum llama_split_mode split_mode; // how to split the model across multiple GPUs

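Based only on the field comments above, tensor_buft_overrides is consumed as a NULL-terminated array of (pattern, buffer type) pairs matched against tensor names. A hedged usage sketch; the pattern string is illustrative and the exact matching rules are not shown in this diff:

// --- illustrative example (not part of the diff) ---
// Route tensors whose name matches a (hypothetical) pattern to host memory.
struct llama_model_params make_params_with_overrides() {
    static const llama_model_tensor_buft_override overrides[] = {
        { "ffn_.*_exps", ggml_backend_cpu_buffer_type() }, // pattern string is an assumption
        { nullptr,       nullptr },                        // NULL terminator, per the header comment
    };

    struct llama_model_params mparams = llama_model_default_params();
    mparams.tensor_buft_overrides = overrides;
    return mparams;
}
// --- end example ---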
@ -307,7 +322,7 @@ extern "C" {
     };

     // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
-    // https://github.com/ggerganov/llama.cpp/pull/7544
+    // https://github.com/ggml-org/llama.cpp/pull/7544
     struct llama_context_params {
         uint32_t n_ctx;   // text context, 0 = from model
         uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode
@ -320,7 +335,7 @@ extern "C" {
         enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
         enum llama_attention_type attention_type; // attention type to use for embeddings

-        // ref: https://github.com/ggerganov/llama.cpp/pull/2054
+        // ref: https://github.com/ggml-org/llama.cpp/pull/2054
         float rope_freq_base;   // RoPE base frequency, 0 = from model
         float rope_freq_scale;  // RoPE frequency scaling factor, 0 = from model
         float yarn_ext_factor;  // YaRN extrapolation mix factor, negative = from model
@ -353,17 +368,18 @@ extern "C" {

     // model quantization parameters
     typedef struct llama_model_quantize_params {
         int32_t nthread;                     // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
         enum llama_ftype ftype;              // quantize to this llama_ftype
         enum ggml_type output_tensor_type;   // output tensor type
         enum ggml_type token_embedding_type; // token embeddings tensor type
         bool allow_requantize;               // allow quantizing non-f32/f16 tensors
         bool quantize_output_tensor;         // quantize output.weight
         bool only_copy;                      // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
         bool pure;                           // quantize all tensors to the default type
         bool keep_split;                     // quantize to the same number of shards
         void * imatrix;                      // pointer to importance matrix data
         void * kv_overrides;                 // pointer to vector containing overrides
+        void * tensor_types;                 // pointer to vector containing tensor types
     } llama_model_quantize_params;

     typedef struct llama_logit_bias {
@ -385,7 +401,7 @@ extern "C" {
     struct llama_adapter_lora;

     // Helpers for getting default parameters
-    // TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172)
+    // TODO: update API to start accepting pointers to params structs (https://github.com/ggml-org/llama.cpp/discussions/9172)
     LLAMA_API struct llama_model_params llama_model_default_params(void);
     LLAMA_API struct llama_context_params llama_context_default_params(void);
     LLAMA_API struct llama_sampler_chain_params llama_sampler_chain_default_params(void);
@ -468,7 +484,8 @@ extern "C" {
     DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");

     LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
-    LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
+    LLAMA_API struct llama_kv_cache * llama_get_kv_self ( struct llama_context * ctx);
+    LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type

     LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
     LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
@ -477,6 +494,7 @@ extern "C" {
     LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
+    LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);

     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
@ -584,7 +602,7 @@ extern "C" {
     // KV cache
     //

-    // TODO: remove llama_kv_cache_view_* API
+    // TODO: start using struct llama_kv_cache

     // Information associated with an individual cell in the KV cache view.
     struct llama_kv_cache_view_cell {
@ -639,13 +657,19 @@ extern "C" {

     // Returns the number of tokens in the KV cache (slow, use only for debug)
     // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
-    LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx);
+    LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx);
+
+    DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx),
+        "use llama_kv_self_n_tokens instead");

     // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
-    LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
+    LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx);
+
+    DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx),
+        "use llama_kv_self_used_cells instead");

     // Clear the KV cache - both cell info is erased and KV data is zeroed
-    LLAMA_API void llama_kv_cache_clear(
+    LLAMA_API void llama_kv_self_clear(
             struct llama_context * ctx);

     // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
@ -653,7 +677,7 @@ extern "C" {
     // seq_id < 0 : match any sequence
     // p0 < 0     : [0,  p1]
     // p1 < 0     : [p0, inf)
-    LLAMA_API bool llama_kv_cache_seq_rm(
+    LLAMA_API bool llama_kv_self_seq_rm(
             struct llama_context * ctx,
             llama_seq_id seq_id,
             llama_pos p0,
@ -663,7 +687,7 @@ extern "C" {
     // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
     // p0 < 0 : [0,  p1]
     // p1 < 0 : [p0, inf)
-    LLAMA_API void llama_kv_cache_seq_cp(
+    LLAMA_API void llama_kv_self_seq_cp(
             struct llama_context * ctx,
             llama_seq_id seq_id_src,
             llama_seq_id seq_id_dst,
@ -671,17 +695,17 @@ extern "C" {
             llama_pos p1);

     // Removes all tokens that do not belong to the specified sequence
-    LLAMA_API void llama_kv_cache_seq_keep(
+    LLAMA_API void llama_kv_self_seq_keep(
             struct llama_context * ctx,
             llama_seq_id seq_id);

     // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
     // If the KV cache is RoPEd, the KV data is updated accordingly:
     //   - lazily on next llama_decode()
-    //   - explicitly with llama_kv_cache_update()
+    //   - explicitly with llama_kv_self_update()
     // p0 < 0 : [0,  p1]
     // p1 < 0 : [p0, inf)
-    LLAMA_API void llama_kv_cache_seq_add(
+    LLAMA_API void llama_kv_self_seq_add(
             struct llama_context * ctx,
             llama_seq_id seq_id,
             llama_pos p0,
@ -691,10 +715,10 @@ extern "C" {
     // Integer division of the positions by factor of `d > 1`
     // If the KV cache is RoPEd, the KV data is updated accordingly:
     //   - lazily on next llama_decode()
-    //   - explicitly with llama_kv_cache_update()
+    //   - explicitly with llama_kv_self_update()
     // p0 < 0 : [0,  p1]
     // p1 < 0 : [p0, inf)
-    LLAMA_API void llama_kv_cache_seq_div(
+    LLAMA_API void llama_kv_self_seq_div(
             struct llama_context * ctx,
             llama_seq_id seq_id,
             llama_pos p0,
@ -702,24 +726,76 @@ extern "C" {
             int d);

     // Returns the largest position present in the KV cache for the specified sequence
-    LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
+    LLAMA_API llama_pos llama_kv_self_seq_pos_max(
             struct llama_context * ctx,
             llama_seq_id seq_id);

-    // TODO: the llama_kv_cache_defrag and llama_kv_cache_update API tightly couples llama_context with llama_kv_cache
-    // how to avoid this?
-
     // Defragment the KV cache
     // This will be applied:
     //   - lazily on next llama_decode()
-    //   - explicitly with llama_kv_cache_update()
-    LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx);
-
-    // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
-    LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
+    //   - explicitly with llama_kv_self_update()
+    LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx);

     // Check if the context supports KV cache shifting
-    LLAMA_API bool llama_kv_cache_can_shift(struct llama_context * ctx);
+    LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
+
+    // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
+    LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
+
+    DEPRECATED(LLAMA_API void llama_kv_cache_clear(
+            struct llama_context * ctx),
+        "use llama_kv_self_clear instead");
+
+    DEPRECATED(LLAMA_API bool llama_kv_cache_seq_rm(
+            struct llama_context * ctx,
+            llama_seq_id seq_id,
+            llama_pos p0,
+            llama_pos p1),
+        "use llama_kv_self_seq_rm instead");
+
+    DEPRECATED(LLAMA_API void llama_kv_cache_seq_cp(
+            struct llama_context * ctx,
+            llama_seq_id seq_id_src,
+            llama_seq_id seq_id_dst,
+            llama_pos p0,
+            llama_pos p1),
+        "use llama_kv_self_seq_cp instead");
+
+    DEPRECATED(LLAMA_API void llama_kv_cache_seq_keep(
+            struct llama_context * ctx,
+            llama_seq_id seq_id),
+        "use llama_kv_self_seq_keep instead");
+
+    DEPRECATED(LLAMA_API void llama_kv_cache_seq_add(
+            struct llama_context * ctx,
+            llama_seq_id seq_id,
+            llama_pos p0,
+            llama_pos p1,
+            llama_pos delta),
+        "use llama_kv_self_seq_add instead");
+
+    DEPRECATED(LLAMA_API void llama_kv_cache_seq_div(
+            struct llama_context * ctx,
+            llama_seq_id seq_id,
+            llama_pos p0,
+            llama_pos p1,
+            int d),
+        "use llama_kv_self_seq_div instead");
+
+    DEPRECATED(LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
+            struct llama_context * ctx,
+            llama_seq_id seq_id),
+        "use llama_kv_self_seq_pos_max instead");
+
+    DEPRECATED(LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx),
+        "use llama_kv_self_defrag instead");
+
+    DEPRECATED(LLAMA_API bool llama_kv_cache_can_shift(const struct llama_context * ctx),
+        "use llama_kv_self_can_shift instead");
+
+    DEPRECATED(LLAMA_API void llama_kv_cache_update(struct llama_context * ctx),
+        "use llama_kv_self_update instead");

     //
     // State / sessions
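All llama_kv_cache_* entry points above now have llama_kv_self_* equivalents and are marked DEPRECATED. A minimal migration sketch using only the new names introduced in this hunk:

// --- illustrative example (not part of the diff) ---
static void reset_context(struct llama_context * ctx) {
    // previously llama_kv_cache_clear(ctx);
    llama_kv_self_clear(ctx);

    // previously llama_kv_cache_seq_rm(ctx, 0, -1, -1);
    llama_kv_self_seq_rm(ctx, /* seq_id = */ 0, /* p0 = */ -1, /* p1 = */ -1);

    // apply any pending K-shifts / defragmentation
    llama_kv_self_update(ctx);
}
// --- end example ---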
@ -883,6 +959,10 @@ extern "C" {
     // If set to true, the model will only attend to the past tokens
     LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);

+    // Set whether the model is in warmup mode or not
+    // If true, all model tensors are activated during llama_decode() to load and cache their weights.
+    LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup);
+
     // Set abort callback
     LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);

@ -1040,7 +1120,7 @@ extern "C" {

     /// Apply chat template. Inspired by hf apply_chat_template() on python.
     /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
-    /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
+    /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
     /// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead.
     /// @param chat Pointer to a list of multiple llama_chat_message
     /// @param n_msg Number of llama_chat_message in this chat
@ -1114,11 +1194,12 @@ extern "C" {
     };

     struct llama_sampler {
-        struct llama_sampler_i * iface;
+        const struct llama_sampler_i * iface;
         llama_sampler_context_t ctx;
     };

     // mirror of llama_sampler_i:
+    LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
     LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
     LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
     LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);
@ -1148,15 +1229,16 @@ extern "C" {
|
|||||||
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
||||||
/// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
|
/// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
|
||||||
DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void),
|
DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void),
|
||||||
"will be removed in the future (see https://github.com/ggerganov/llama.cpp/pull/9896#discussion_r1800920915)");
|
"will be removed in the future (see https://github.com/ggml-org/llama.cpp/pull/9896#discussion_r1800920915)");
|
||||||
|
|
||||||
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||||
|
/// Setting k <= 0 makes this a noop
|
||||||
LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
|
LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
|
||||||
|
|
||||||
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||||
LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, size_t min_keep);
|
LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, size_t min_keep);
|
||||||
|
|
||||||
/// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
|
/// @details Minimum P sampling as described in https://github.com/ggml-org/llama.cpp/pull/3841
|
||||||
LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep);
|
LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep);
|
||||||
|
|
||||||
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
|
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
|
||||||
@@ -1171,6 +1253,9 @@ extern "C" {
     /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
     LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);

+    /// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641
+    LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(float n);
+
     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
     /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
@@ -1194,22 +1279,38 @@ extern "C" {
     float tau,
     float eta);

+    /// @details Intializes a GBNF grammar, see grammars/README.md for details.
+    /// @param vocab The vocabulary that this grammar will be used with.
+    /// @param grammar_str The production rules for the grammar, encoded as a string. Returns an empty grammar if empty. Returns NULL if parsing of grammar_str fails.
+    /// @param grammar_root The name of the start symbol for the grammar.
     LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
         const struct llama_vocab * vocab,
         const char * grammar_str,
         const char * grammar_root);

-    /// @details Lazy grammar sampler, introduced in https://github.com/ggerganov/llama.cpp/pull/9639
-    /// @param trigger_words A list of words that will trigger the grammar sampler. This may be updated to a loose regex syntax (w/ ^) in a near future.
-    /// @param trigger_tokens A list of tokens that will trigger the grammar sampler.
-    LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy(
+    DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy(
         const struct llama_vocab * vocab,
         const char * grammar_str,
         const char * grammar_root,
         const char ** trigger_words,
         size_t num_trigger_words,
         const llama_token * trigger_tokens,
-        size_t num_trigger_tokens);
+        size_t num_trigger_tokens),
+        "use llama_sampler_init_grammar_lazy_patterns instead");
+
+    /// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639
+    /// @param trigger_patterns A list of patterns that will trigger the grammar sampler. Pattern will be matched from the start of the generation output, and grammar sampler will be fed content starting from its first match group.
+    /// @param trigger_tokens A list of tokens that will trigger the grammar sampler. Grammar sampler will be fed content starting from the trigger token included.
+    LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy_patterns(
+        const struct llama_vocab * vocab,
+        const char * grammar_str,
+        const char * grammar_root,
+        const char ** trigger_patterns,
+        size_t num_trigger_patterns,
+        const llama_token * trigger_tokens,
+        size_t num_trigger_tokens);

     /// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
     LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
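A small usage sketch (not from the diff) for the grammar sampler declared above; the GBNF string and root symbol are made-up examples and error handling is omitted.

```cpp
#include "llama.h"

static struct llama_sampler * make_grammar_sampler(const struct llama_model * model) {
    const struct llama_vocab * vocab = llama_model_get_vocab(model);

    // toy grammar: the output must be an empty JSON-style object
    const char * grammar_str =
        "root   ::= object\n"
        "object ::= \"{\" \"}\"\n";

    // per the doc comment above, NULL is returned if grammar_str fails to parse
    return llama_sampler_init_grammar(vocab, grammar_str, "root");
}
```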
@@ -618,7 +618,14 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
     result.reserve(utf8.size());
     size_t offset = 0;
     while (offset < utf8.size()) {
-        result.push_back(unicode_cpt_from_utf8(utf8, offset));
+        try {
+            result.push_back(unicode_cpt_from_utf8(utf8, offset));
+        }
+        catch (const std::invalid_argument & /*ex*/) {
+            // Silently ignore invalid UTF-8 input to avoid leaking the exception beyond llama_tokenize
+            ++offset;
+            result.emplace_back(0xFFFD); // replacement character
+        }
     }
     return result;
 }
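To illustrate the behaviour this hunk introduces, here is a standalone sketch (not the project's own decoder) that maps invalid UTF-8 bytes to U+FFFD and keeps going instead of throwing:

```cpp
#include <cstdint>
#include <string>
#include <vector>

static std::vector<uint32_t> cpts_with_replacement(const std::string & s) {
    std::vector<uint32_t> out;
    size_t i = 0;
    while (i < s.size()) {
        const uint8_t b = (uint8_t) s[i];
        // classify the lead byte: 1..4 byte sequence, 0 = invalid lead
        const size_t len = b < 0x80 ? 1 : (b >> 5) == 0x6 ? 2 : (b >> 4) == 0xE ? 3 : (b >> 3) == 0x1E ? 4 : 0;
        uint32_t cp = 0;
        bool ok = len > 0 && i + len <= s.size();
        if (ok) {
            static const uint32_t mask[5] = {0, 0x7F, 0x1F, 0x0F, 0x07};
            cp = b & mask[len];
            for (size_t k = 1; k < len && ok; ++k) {
                const uint8_t c = (uint8_t) s[i + k];
                ok = (c & 0xC0) == 0x80;          // must be a continuation byte
                cp = (cp << 6) | (c & 0x3F);
            }
        }
        if (ok) {
            out.push_back(cp);
            i += len;
        } else {
            out.push_back(0xFFFD);                // replacement character, mirroring the patch
            i += 1;
        }
    }
    return out;
}
```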
@@ -701,7 +708,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
     const auto cpts = unicode_cpts_from_utf8(text);

     // generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
-    // ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935
+    // ref: https://github.com/ggml-org/llama.cpp/pull/6920#issuecomment-2081479935
     std::string text_collapsed;
     if (need_collapse) {
         // collapse all unicode categories
@@ -2,7 +2,7 @@
 #
 # Transcribe twitch.tv livestream by feeding audio input to whisper.cpp at regular intervals
 # Thanks to @keyehzy
-# ref: https://github.com/ggerganov/whisper.cpp/issues/209
+# ref: https://github.com/ggml-org/whisper.cpp/issues/209
 #
 # The script currently depends on the third-party tool "streamlink"
 # On Mac OS, you can install it via "brew install streamlink"
@@ -14,6 +14,8 @@ set(SOURCE_FILES
     ${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ggml-cpu.cpp
     ${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/unary-ops.cpp
     ${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/binary-ops.cpp
+    ${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/vec.cpp
+    ${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ops.cpp
    ${WHISPER_LIB_DIR}/ggml/src/ggml-alloc.c
    ${WHISPER_LIB_DIR}/ggml/src/ggml-backend.cpp
    ${WHISPER_LIB_DIR}/ggml/src/ggml-backend-reg.cpp
@@ -34,6 +34,8 @@ if (NOT GGML_HOME)
     ${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ggml-cpu-traits.cpp
     ${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/unary-ops.cpp
     ${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/binary-ops.cpp
+    ${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/vec.cpp
+    ${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ops.cpp
     )
 endif()

@@ -5,7 +5,7 @@
 # This simple script is called by Neovim to capture audio from the microphone and transcribe it with Whisper.
 # In order for this to work, you need to clone the whisper.cpp repo and build the 'stream' tool
 #
-# git clone https://github.com/ggerganov/whisper.cpp
+# git clone https://github.com/ggml-org/whisper.cpp
 # cd whisper.cpp
 # make stream
 #
@@ -31,7 +31,7 @@
 model="base.en"

 # export the path to the whisper.cpp repo in the WHISPER_CPP_HOME env variable
-# https://github.com/ggerganov/whisper.cpp
+# https://github.com/ggml-org/whisper.cpp
 cd "${WHISPER_CPP_HOME}"

 if [ ! -f ./stream ] ; then
@@ -36,7 +36,7 @@ set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
     -s MAXIMUM_MEMORY=2000MB \
     -s ALLOW_MEMORY_GROWTH=1 \
     -s FORCE_FILESYSTEM=1 \
-    -s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
+    -s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap', 'HEAPU8']\" \
     ${EXTRA_FLAGS} \
     ")

@@ -30,7 +30,7 @@ Link: https://ggerganov.github.io/whisper.cpp/

 ```bash (v3.1.2)
 # build using Emscripten
-git clone https://github.com/ggerganov/whisper.cpp
+git clone https://github.com/ggml-org/whisper.cpp
 cd whisper.cpp
 mkdir build-em && cd build-em
 emcmake cmake ..
@@ -65,13 +65,14 @@ EMSCRIPTEN_BINDINGS(whisper) {
         }

         struct whisper_full_params params = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
+        bool is_multilingual = whisper_is_multilingual(g_contexts[index]);

         params.print_realtime = true;
         params.print_progress = false;
         params.print_timestamps = true;
         params.print_special = false;
         params.translate = translate;
-        params.language = whisper_is_multilingual(g_contexts[index]) ? lang.c_str() : "en";
+        params.language = is_multilingual ? strdup(lang.c_str()) : "en";
         params.n_threads = std::min(nthreads, std::min(16, mpow2(std::thread::hardware_concurrency())));
         params.offset_ms = 0;

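The strdup/free pair in this change is an ownership fix: whisper_full_params.language is a plain const char *, so a dynamically chosen language string has to stay alive for the whole whisper_full() call and be released by the caller afterwards. A hedged sketch of the same pattern outside the Emscripten binding (function and variable names are illustrative):

```cpp
#include <cstdlib>
#include <cstring>
#include <string>
#include <vector>
#include "whisper.h"

static int transcribe(struct whisper_context * wctx, const std::vector<float> & pcmf32, const std::string & lang) {
    struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    const bool is_multilingual = whisper_is_multilingual(wctx);
    params.language = is_multilingual ? strdup(lang.c_str()) : "en";

    const int ret = whisper_full(wctx, params, pcmf32.data(), (int) pcmf32.size());

    if (is_multilingual) {
        free((void *) params.language); // only free what we allocated
    }
    return ret;
}
```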
@@ -102,10 +103,13 @@ EMSCRIPTEN_BINDINGS(whisper) {

         // run the worker
         {
-            g_worker = std::thread([index, params, pcmf32 = std::move(pcmf32)]() {
+            g_worker = std::thread([index, params, pcmf32 = std::move(pcmf32), is_multilingual]() {
                 whisper_reset_timings(g_contexts[index]);
                 whisper_full(g_contexts[index], params, pcmf32.data(), pcmf32.size());
                 whisper_print_timings(g_contexts[index]);
+                if (is_multilingual) {
+                    free((void*)params.language);
+                }
             });
         }

@@ -25,12 +25,12 @@
 # SOFTWARE.

 # Small shell script to more easily automatically download and transcribe live stream VODs.
-# This uses YT-DLP, ffmpeg and the CPP version of Whisper: https://github.com/ggerganov/whisper.cpp
+# This uses YT-DLP, ffmpeg and the CPP version of Whisper: https://github.com/ggml-org/whisper.cpp
 # Use `./examples/yt-wsp.sh help` to print help info.
 #
 # Sample usage:
 #
-# git clone https://github.com/ggerganov/whisper.cpp
+# git clone https://github.com/ggml-org/whisper.cpp
 # cd whisper.cpp
 # make
 # ./examples/yt-wsp.sh https://www.youtube.com/watch?v=1234567890
@@ -44,7 +44,7 @@ SCRIPT_DIR="${SCRIPT_PATH%/*}"

 ################################################################################
 # Documentation on downloading models can be found in the whisper.cpp repo:
-# https://github.com/ggerganov/whisper.cpp/#usage
+# https://github.com/ggml-org/whisper.cpp/#usage
 #
 # note: unless a multilingual model is specified, WHISPER_LANG will be ignored
 # and the video will be transcribed as if the audio were in the English language
@@ -103,10 +103,10 @@ check_requirements() {
     fi;

     if ! command -v "${WHISPER_EXECUTABLE}" &>/dev/null; then
-        echo "The C++ implementation of Whisper is required: https://github.com/ggerganov/whisper.cpp"
+        echo "The C++ implementation of Whisper is required: https://github.com/ggml-org/whisper.cpp"
         echo "Sample usage:";
         echo "";
-        echo " git clone https://github.com/ggerganov/whisper.cpp";
+        echo " git clone https://github.com/ggml-org/whisper.cpp";
         echo " cd whisper.cpp";
         echo " make";
         echo " ./examples/yt-wsp.sh https://www.youtube.com/watch?v=1234567890";
@@ -107,6 +107,7 @@ message(DEBUG "INS_ENB : ${INS_ENB}")
 option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF)
 option(GGML_CPU_AARCH64 "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON)
 option(GGML_CPU_KLEIDIAI "ggml: use KleidiAI optimized kernels if applicable" OFF)
+option(GGML_SSE42 "ggml: enable SSE 4.2" ${INS_ENB})
 option(GGML_AVX "ggml: enable AVX" ${INS_ENB})
 option(GGML_AVX_VNNI "ggml: enable AVX-VNNI" OFF)
 option(GGML_AVX2 "ggml: enable AVX2" ${INS_ENB})
@@ -170,7 +171,6 @@ option(GGML_HIP "ggml: use HIP"
 option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF)
 option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
 option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
-option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF)
 option(GGML_VULKAN "ggml: use Vulkan" OFF)
 option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)
 option(GGML_VULKAN_DEBUG "ggml: enable Vulkan debug output" OFF)
@@ -360,3 +360,29 @@ write_basic_package_version_file(
 install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake
         ${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake
         DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml)
+
+if (MSVC)
+    set(MSVC_WARNING_FLAGS
+        /wd4005  # Macro redefinition
+        /wd4244  # Conversion from one type to another type, possible loss of data
+        /wd4267  # Conversion from 'size_t' to a smaller type, possible loss of data
+        /wd4996  # Disable POSIX deprecation warnings
+        /wd4702  # Unreachable code warnings
+    )
+    function(disable_msvc_warnings target_name)
+        if(TARGET ${target_name})
+            target_compile_options(${target_name} PRIVATE ${MSVC_WARNING_FLAGS})
+        endif()
+    endfunction()
+
+    disable_msvc_warnings(ggml-base)
+    disable_msvc_warnings(ggml)
+    disable_msvc_warnings(ggml-cpu)
+    disable_msvc_warnings(ggml-cpu-x64)
+    disable_msvc_warnings(ggml-cpu-sse42)
+    disable_msvc_warnings(ggml-cpu-sandybridge)
+    disable_msvc_warnings(ggml-cpu-haswell)
+    disable_msvc_warnings(ggml-cpu-skylakex)
+    disable_msvc_warnings(ggml-cpu-icelake)
+    disable_msvc_warnings(ggml-cpu-alderlake)
+endif()
@@ -24,7 +24,7 @@ typedef std::unique_ptr<gguf_context, gguf_context_deleter> gguf_context_ptr;

 struct ggml_gallocr_deleter { void operator()(ggml_gallocr_t galloc) { ggml_gallocr_free(galloc); } };

-typedef std::unique_ptr<ggml_gallocr_t, ggml_gallocr_deleter> ggml_gallocr_ptr;
+typedef std::unique_ptr<ggml_gallocr, ggml_gallocr_deleter> ggml_gallocr_ptr;

 // ggml-backend

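Why this one-word change matters (explanatory note, not part of the diff): ggml_gallocr_t is itself a pointer type (struct ggml_gallocr *), so parameterizing std::unique_ptr on ggml_gallocr_t would make it manage a ggml_gallocr ** and never match the deleter; using the struct type makes .get() return a ggml_gallocr_t as intended. A minimal usage sketch, assuming the RAII typedefs from this header and the usual ggml-alloc/ggml-backend APIs (exact header locations may differ between versions):

```cpp
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-cpu.h"
#include "ggml-cpp.h"

void gallocr_raii_example() {
    // wrap a graph allocator for the CPU buffer type in the RAII handle
    ggml_gallocr_ptr galloc(ggml_gallocr_new(ggml_backend_cpu_buffer_type()));

    // galloc.get() yields a ggml_gallocr_t; ggml_gallocr_free() runs automatically on scope exit
}
```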
@@ -133,6 +133,11 @@ extern "C" {

     GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);

+    GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t);
+    GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t);
+    GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t);
+    GGML_BACKEND_API void ggml_cpu_bf16_to_fp32(const ggml_bf16_t *, float *, int64_t);
+
 #ifdef __cplusplus
 }
 #endif
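A short sketch (not part of the diff) exercising the newly exposed CPU conversion helpers; the signatures are taken from the declarations above and the data is arbitrary:

```cpp
#include <vector>
#include "ggml.h"
#include "ggml-cpu.h"

void fp16_roundtrip() {
    const std::vector<float> src = {0.1f, -2.5f, 3.75f, 1024.0f};

    std::vector<ggml_fp16_t> half(src.size());
    std::vector<float>       back(src.size());

    ggml_cpu_fp32_to_fp16(src.data(),  half.data(), (int64_t) src.size());
    ggml_cpu_fp16_to_fp32(half.data(), back.data(), (int64_t) back.size());
    // back[i] now holds src[i] rounded to half precision
}
```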
@@ -7,6 +7,9 @@
 extern "C" {
 #endif

+#define RPC_PROTO_MAJOR_VERSION 2
+#define RPC_PROTO_MINOR_VERSION 0
+#define RPC_PROTO_PATCH_VERSION 0
 #define GGML_RPC_MAX_SERVERS 16

 // backend API
@@ -393,8 +393,8 @@ extern "C" {

     // precision
     enum ggml_prec {
-        GGML_PREC_DEFAULT,
-        GGML_PREC_F32,
+        GGML_PREC_DEFAULT = 0, // stored as ggml_tensor.op_params, 0 by default
+        GGML_PREC_F32 = 10,
     };

     // model file types
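As the new comment notes, these enumerator values are stored in ggml_tensor.op_params. A typical consumer is the matmul precision override; a minimal sketch, assuming the standard ggml context helpers and ggml_mul_mat_set_prec():

```cpp
#include "ggml.h"

void matmul_f32_prec() {
    struct ggml_init_params ip = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 64, 32); // shared dim 64
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 8);

    struct ggml_tensor * c = ggml_mul_mat(ctx, a, b);
    ggml_mul_mat_set_prec(c, GGML_PREC_F32); // request full-precision accumulation

    ggml_free(ctx);
}
```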
@@ -481,6 +481,7 @@ extern "C" {
         GGML_OP_CONV_TRANSPOSE_1D,
         GGML_OP_IM2COL,
         GGML_OP_IM2COL_BACK,
+        GGML_OP_CONV_2D_DW,
         GGML_OP_CONV_TRANSPOSE_2D,
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
@@ -507,17 +508,12 @@ extern "C" {

         GGML_OP_UNARY,

-        GGML_OP_MAP_UNARY,
-        GGML_OP_MAP_BINARY,
-
-        GGML_OP_MAP_CUSTOM1_F32,
-        GGML_OP_MAP_CUSTOM2_F32,
-        GGML_OP_MAP_CUSTOM3_F32,

         GGML_OP_MAP_CUSTOM1,
         GGML_OP_MAP_CUSTOM2,
         GGML_OP_MAP_CUSTOM3,

+        GGML_OP_CUSTOM,
+
         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
         GGML_OP_OPT_STEP_ADAMW,
@@ -682,6 +678,9 @@ extern "C" {
     GGML_API bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1
     GGML_API bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2

+    // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
+    GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor);
+
     GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
     GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);

@@ -1665,7 +1664,7 @@ extern "C" {
         struct ggml_tensor * a,
         struct ggml_tensor * b);

-    // depthwise
+    // depthwise (via im2col and mul_mat)
     GGML_API struct ggml_tensor * ggml_conv_2d_dw(
         struct ggml_context * ctx,
         struct ggml_tensor * a, // convolution kernel
@@ -1677,6 +1676,22 @@ extern "C" {
         int d0, // dilation dimension 0
         int d1); // dilation dimension 1

+    // Depthwise 2D convolution
+    // may be faster than ggml_conv_2d_dw, but not available in all backends
+    // a:   KW KH 1 C convolution kernel
+    // b:   W  H  C N input data
+    // res: W_out H_out C N
+    GGML_API struct ggml_tensor * ggml_conv_2d_dw_direct(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        int stride0,
+        int stride1,
+        int pad0,
+        int pad1,
+        int dilation0,
+        int dilation1);
+
     GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
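A usage sketch (not part of the diff) for the new direct depthwise convolution, following the tensor layouts documented above; the 3x3 kernel, stride, padding and dilation values are illustrative:

```cpp
#include "ggml.h"

// input is laid out W x H x C x N, the kernel as KW x KH x 1 x C
struct ggml_tensor * depthwise_3x3(struct ggml_context * ctx, struct ggml_tensor * input) {
    const int64_t C = input->ne[2];

    struct ggml_tensor * kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 1, C);

    // stride 1x1, padding 1x1, dilation 1x1 -> same spatial size as the input
    return ggml_conv_2d_dw_direct(ctx, kernel, input, 1, 1, 1, 1, 1, 1);
}
```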
@@ -1722,24 +1737,29 @@ extern "C" {
         float p0,
         float p1);

-    // nearest interpolate
+    enum ggml_scale_mode {
+        GGML_SCALE_MODE_NEAREST = 0,
+        GGML_SCALE_MODE_BILINEAR = 1,
+    };
+
+    // interpolate
     // multiplies ne0 and ne1 by scale factor
-    // used in stable-diffusion
     GGML_API struct ggml_tensor * ggml_upscale(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-        int scale_factor);
+        int scale_factor,
+        enum ggml_scale_mode mode);

-    // nearest interpolate
-    // nearest interpolate to specified dimensions
-    // used in tortoise.cpp
+    // interpolate
+    // interpolate scale to specified dimensions
     GGML_API struct ggml_tensor * ggml_upscale_ext(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         int ne0,
         int ne1,
         int ne2,
-        int ne3);
+        int ne3,
+        enum ggml_scale_mode mode);

     // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
     GGML_API struct ggml_tensor * ggml_pad(
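A tiny sketch (not part of the diff) showing the new mode argument; per the old comments, ggml_upscale() previously always used nearest-neighbour interpolation:

```cpp
#include "ggml.h"

// 2x bilinear upscale of a W x H x C x N tensor
struct ggml_tensor * upscale_2x_bilinear(struct ggml_context * ctx, struct ggml_tensor * img) {
    return ggml_upscale(ctx, img, 2, GGML_SCALE_MODE_BILINEAR);
}
```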
@@ -1916,83 +1936,6 @@ extern "C" {

     // custom operators

-    typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
-    typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
-
-    typedef void (*ggml_custom1_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *);
-    typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
-    typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        ggml_unary_op_f32_t fun),
-        "use ggml_map_custom1 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_inplace_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        ggml_unary_op_f32_t fun),
-        "use ggml_map_custom1_inplace instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        ggml_binary_op_f32_t fun),
-        "use ggml_map_custom2 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_inplace_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        ggml_binary_op_f32_t fun),
-        "use ggml_map_custom2_inplace instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        ggml_custom1_op_f32_t fun),
-        "use ggml_map_custom1 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_inplace_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        ggml_custom1_op_f32_t fun),
-        "use ggml_map_custom1_inplace instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        ggml_custom2_op_f32_t fun),
-        "use ggml_map_custom2 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_inplace_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        ggml_custom2_op_f32_t fun),
-        "use ggml_map_custom2_inplace instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        struct ggml_tensor * c,
-        ggml_custom3_op_f32_t fun),
-        "use ggml_map_custom3 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_inplace_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        struct ggml_tensor * c,
-        ggml_custom3_op_f32_t fun),
-        "use ggml_map_custom3_inplace instead");
-
-    // custom operators v2
-
     typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
     typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata);
     typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata);
@@ -2048,6 +1991,30 @@ extern "C" {
         int n_tasks,
         void * userdata);

+    typedef void (*ggml_custom_op_t)(struct ggml_tensor * dst , int ith, int nth, void * userdata);
+
+    GGML_API struct ggml_tensor * ggml_custom_4d(
+        struct ggml_context * ctx,
+        enum ggml_type type,
+        int64_t ne0,
+        int64_t ne1,
+        int64_t ne2,
+        int64_t ne3,
+        struct ggml_tensor ** args,
+        int n_args,
+        ggml_custom_op_t fun,
+        int n_tasks,
+        void * userdata);
+
+    GGML_API struct ggml_tensor * ggml_custom_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor ** args,
+        int n_args,
+        ggml_custom_op_t fun,
+        int n_tasks,
+        void * userdata);
+
     // loss function

     GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
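A hedged sketch (not part of the diff) of the new generic custom op. The callback below only touches dst, the thread indices, and userdata; how the optional args tensors are surfaced to the callback is not visible in this hunk, so the example passes none and assumes that an empty args list is allowed:

```cpp
#include "ggml.h"

struct fill_args { float value; };

// split rows of dst across the n_tasks worker threads; each thread takes every nth row
static void fill_rows(struct ggml_tensor * dst, int ith, int nth, void * userdata) {
    const struct fill_args * fa = (const struct fill_args *) userdata;
    const int64_t nrows = ggml_nrows(dst);
    for (int64_t r = ith; r < nrows; r += nth) {
        float * row = (float *) ((char *) dst->data + r*dst->nb[1]);
        for (int64_t i = 0; i < dst->ne[0]; ++i) {
            row[i] = fa->value;
        }
    }
}

struct ggml_tensor * make_filled(struct ggml_context * ctx, struct fill_args * fa) {
    return ggml_custom_4d(ctx, GGML_TYPE_F32, 16, 8, 1, 1, /*args =*/ NULL, /*n_args =*/ 0,
                          fill_rows, /*n_tasks =*/ GGML_N_TASKS_MAX, fa);
}
```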
@@ -214,7 +214,7 @@ add_library(ggml
 target_link_libraries(ggml PUBLIC ggml-base)

 if (CMAKE_SYSTEM_NAME MATCHES "Linux")
-    target_link_libraries(ggml PRIVATE dl stdc++fs)
+    target_link_libraries(ggml PRIVATE dl)
 endif()

 function(ggml_add_backend_library backend)
@@ -267,6 +267,7 @@ function(ggml_add_cpu_backend_variant tag_name)
     set(GGML_CPU_TAG_NAME ${tag_name})
     # other: OPENMP LLAMAFILE CPU_HBM
     foreach (feat NATIVE
+                  SSE42
                   AVX AVX2 BMI2 AVX_VNNI FMA F16C
                   AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16
                   AMX_TILE AMX_INT8 AMX_BF16)
@@ -286,14 +287,16 @@ if (GGML_CPU_ALL_VARIANTS)
     if (NOT GGML_BACKEND_DL)
         message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS requires GGML_BACKEND_DL")
    endif()
-    ggml_add_cpu_backend_variant(sandybridge AVX)
-    ggml_add_cpu_backend_variant(haswell AVX F16C AVX2 BMI2 FMA)
-    ggml_add_cpu_backend_variant(skylakex AVX F16C AVX2 BMI2 FMA AVX512)
-    ggml_add_cpu_backend_variant(icelake AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
-    ggml_add_cpu_backend_variant(alderlake AVX F16C AVX2 BMI2 FMA AVX_VNNI)
+    ggml_add_cpu_backend_variant(x64)
+    ggml_add_cpu_backend_variant(sse42 SSE42)
+    ggml_add_cpu_backend_variant(sandybridge SSE42 AVX)
+    ggml_add_cpu_backend_variant(haswell SSE42 AVX F16C AVX2 BMI2 FMA)
+    ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C AVX2 BMI2 FMA AVX512)
+    ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
+    ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI)
     if (NOT MSVC)
         # MSVC doesn't support AMX
-        ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
+        ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
     endif()
 elseif (GGML_CPU)
     ggml_add_cpu_backend_variant_impl("")
Some files were not shown because too many files have changed in this diff.