Compare commits
1 Commits
v1.5.3
...
bench-memc
Author | SHA1 | Date | |
---|---|---|---|
ee2971bf6a |
@ -1,38 +0,0 @@
|
|||||||
# CUDA-enabled whisper.cpp image.
# Stage "build" compiles the project inside the CUDA devel image; stage "runtime"
# copies the resulting /app tree onto the CUDA runtime image.

ARG UBUNTU_VERSION=22.04
# This needs to generally match the container host's environment.
ARG CUDA_VERSION=12.3.1
# Target the CUDA build image
ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
# Target the CUDA runtime image
ARG BASE_CUDA_RUN_CONTAINER=nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}

FROM ${BASE_CUDA_DEV_CONTAINER} AS build
WORKDIR /app

# Unless otherwise specified, we make a fat build.
ARG CUDA_DOCKER_ARCH=all
# Set nvcc architecture
ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
# Enable cuBLAS (the Makefile switches on the CUDA build when WHISPER_CUBLAS is defined)
ENV WHISPER_CUBLAS=1

# update + install + cleanup stay in one layer so the apt lists never persist in the image
RUN apt-get update && \
    apt-get install -y --no-install-recommends build-essential \
    && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*

# Ref: https://stackoverflow.com/a/53464012
# Keep CUDA_MAIN_VERSION in sync with the major.minor part of CUDA_VERSION above.
ENV CUDA_MAIN_VERSION=12.3
ENV LD_LIBRARY_PATH=/usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH

# Sources are resolved relative to the build context (the repository root), not
# relative to this Dockerfile, so "." is the whole checkout.
COPY . .
RUN make

FROM ${BASE_CUDA_RUN_CONTAINER} AS runtime
WORKDIR /app

# ca-certificates is listed explicitly because --no-install-recommends would otherwise
# skip it, and curl needs it for HTTPS downloads.
RUN apt-get update && \
    apt-get install -y --no-install-recommends ca-certificates curl ffmpeg \
    && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*

COPY --from=build /app /app
# The container command is passed as one string, e.g.
#   docker run ... "./main -m /models/ggml-base.bin -f /audios/jfk.wav"
ENTRYPOINT [ "bash", "-c" ]
|
|
@ -1,19 +0,0 @@
|
|||||||
# CPU-only whisper.cpp image.
# Stage "build" compiles the project; stage "runtime" copies the resulting /app tree
# onto a fresh Ubuntu base that only adds curl and ffmpeg.

FROM ubuntu:22.04 AS build
WORKDIR /app

# update + install + cleanup stay in one layer so the apt lists never persist in the image
RUN apt-get update && \
    apt-get install -y --no-install-recommends build-essential \
    && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*

# Sources are resolved relative to the build context (the repository root), not
# relative to this Dockerfile, so "." is the whole checkout.
COPY . .
RUN make

FROM ubuntu:22.04 AS runtime
WORKDIR /app

# ca-certificates is listed explicitly because --no-install-recommends would otherwise
# skip it, and curl needs it for HTTPS downloads.
RUN apt-get update && \
    apt-get install -y --no-install-recommends ca-certificates curl ffmpeg \
    && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*

COPY --from=build /app /app
# The container command is passed as one string, e.g.
#   docker run ... "./main -m /models/ggml-base.bin -f ./samples/jfk.wav"
ENTRYPOINT [ "bash", "-c" ]
|
|
52
.github/workflows/build.yml
vendored
@ -25,7 +25,6 @@ jobs:
|
|||||||
docker run --platform ${{ matrix.arch }} --rm \
|
docker run --platform ${{ matrix.arch }} --rm \
|
||||||
-v ${{ github.workspace }}:/workspace \
|
-v ${{ github.workspace }}:/workspace \
|
||||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||||
set -e
|
|
||||||
apt update
|
apt update
|
||||||
apt install -y build-essential libsdl2-dev
|
apt install -y build-essential libsdl2-dev
|
||||||
make
|
make
|
||||||
@ -87,7 +86,6 @@ jobs:
|
|||||||
docker run --platform ${{ matrix.arch }} --rm \
|
docker run --platform ${{ matrix.arch }} --rm \
|
||||||
-v ${{ github.workspace }}:/workspace \
|
-v ${{ github.workspace }}:/workspace \
|
||||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||||
set -e
|
|
||||||
apt update
|
apt update
|
||||||
apt install -y build-essential cmake libsdl2-dev
|
apt install -y build-essential cmake libsdl2-dev
|
||||||
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||||
@ -115,9 +113,8 @@ jobs:
|
|||||||
docker run --platform ${{ matrix.arch }} --rm \
|
docker run --platform ${{ matrix.arch }} --rm \
|
||||||
-v ${{ github.workspace }}:/workspace \
|
-v ${{ github.workspace }}:/workspace \
|
||||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||||
set -e
|
|
||||||
apt update
|
apt update
|
||||||
apt install -y clang build-essential cmake libsdl2-dev
|
apt install -y build-essential cmake libsdl2-dev
|
||||||
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang
|
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang
|
||||||
make
|
make
|
||||||
ctest -L gh --output-on-failure'
|
ctest -L gh --output-on-failure'
|
||||||
@ -143,7 +140,6 @@ jobs:
|
|||||||
docker run --platform ${{ matrix.arch }} --rm \
|
docker run --platform ${{ matrix.arch }} --rm \
|
||||||
-v ${{ github.workspace }}:/workspace \
|
-v ${{ github.workspace }}:/workspace \
|
||||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||||
set -e
|
|
||||||
apt update
|
apt update
|
||||||
apt install -y build-essential cmake
|
apt install -y build-essential cmake
|
||||||
cmake . -DCMAKE_BUILD_TYPE=Debug -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON
|
cmake . -DCMAKE_BUILD_TYPE=Debug -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON
|
||||||
@ -166,7 +162,7 @@ jobs:
|
|||||||
s2arc: x64
|
s2arc: x64
|
||||||
jnaPath: win32-x86-64
|
jnaPath: win32-x86-64
|
||||||
- sdl2: ON
|
- sdl2: ON
|
||||||
s2ver: 2.28.5
|
s2ver: 2.26.0
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Clone
|
- name: Clone
|
||||||
@ -221,16 +217,13 @@ jobs:
|
|||||||
sdl2: [ON]
|
sdl2: [ON]
|
||||||
include:
|
include:
|
||||||
- arch: Win32
|
- arch: Win32
|
||||||
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.25/OpenBLAS-0.3.25-x86.zip
|
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.24/OpenBLAS-0.3.24-x86.zip
|
||||||
s2arc: x86
|
s2arc: x86
|
||||||
clblast: OFF
|
|
||||||
- arch: x64
|
- arch: x64
|
||||||
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.25/OpenBLAS-0.3.25-x64.zip
|
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.24/OpenBLAS-0.3.24-x64.zip
|
||||||
s2arc: x64
|
s2arc: x64
|
||||||
clblast: ON
|
|
||||||
clver: 1.6.1
|
|
||||||
- sdl2: ON
|
- sdl2: ON
|
||||||
s2ver: 2.28.5
|
s2ver: 2.26.0
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Clone
|
- name: Clone
|
||||||
@ -255,18 +248,6 @@ jobs:
|
|||||||
7z x sdl2.zip
|
7z x sdl2.zip
|
||||||
echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
|
echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
|
||||||
|
|
||||||
- name: Install OpenCL
|
|
||||||
if: matrix.clblast == 'ON'
|
|
||||||
run: vcpkg.exe --triplet=${{ matrix.arch }}-windows install opencl
|
|
||||||
|
|
||||||
- name: Fetch CLBlast and set CLBlast_DIR
|
|
||||||
if: matrix.clblast == 'ON'
|
|
||||||
run: |
|
|
||||||
C:/msys64/usr/bin/wget.exe -qO clblast.zip https://github.com/CNugteren/CLBlast/releases/download/${{ matrix.clver }}/CLBlast-${{ matrix.clver }}-windows-x64.zip
|
|
||||||
7z x clblast.zip
|
|
||||||
7z x CLBlast-${{ matrix.clver }}-windows-x64.7z
|
|
||||||
echo "CLBlast_DIR=$env:GITHUB_WORKSPACE/CLBlast-${{ matrix.clver }}-windows-x64/lib/cmake/CLBlast" >> $env:GITHUB_ENV
|
|
||||||
|
|
||||||
- name: Configure
|
- name: Configure
|
||||||
run: >
|
run: >
|
||||||
cmake -S . -B ./build -A ${{ matrix.arch }}
|
cmake -S . -B ./build -A ${{ matrix.arch }}
|
||||||
@ -274,7 +255,6 @@ jobs:
|
|||||||
-DWHISPER_OPENBLAS=${{ matrix.blas }}
|
-DWHISPER_OPENBLAS=${{ matrix.blas }}
|
||||||
-DCMAKE_LIBRARY_PATH="$env:OPENBLAS_PATH/lib"
|
-DCMAKE_LIBRARY_PATH="$env:OPENBLAS_PATH/lib"
|
||||||
-DWHISPER_SDL2=${{ matrix.sdl2 }}
|
-DWHISPER_SDL2=${{ matrix.sdl2 }}
|
||||||
-DWHISPER_CLBLAST=${{ matrix.clblast }}
|
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
run: |
|
run: |
|
||||||
@ -289,15 +269,11 @@ jobs:
|
|||||||
if: matrix.sdl2 == 'ON'
|
if: matrix.sdl2 == 'ON'
|
||||||
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
|
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
|
||||||
|
|
||||||
- name: Copy clblast.dll
|
|
||||||
if: matrix.clblast == 'ON'
|
|
||||||
run: copy "$env:CLBlast_DIR/../../clblast.dll" build/bin/${{ matrix.build }}
|
|
||||||
|
|
||||||
- name: Upload binaries
|
- name: Upload binaries
|
||||||
if: matrix.blas == 'ON' && matrix.sdl2 == 'ON'
|
if: matrix.blas == 'ON' && matrix.sdl2 == 'ON'
|
||||||
uses: actions/upload-artifact@v1
|
uses: actions/upload-artifact@v1
|
||||||
with:
|
with:
|
||||||
name: whisper-blas${{ matrix.clblast == 'ON' && '-clblast' || ''}}-bin-${{ matrix.arch }}
|
name: whisper-blas-bin-${{ matrix.arch }}
|
||||||
path: build/bin/${{ matrix.build }}
|
path: build/bin/${{ matrix.build }}
|
||||||
|
|
||||||
windows-cublas:
|
windows-cublas:
|
||||||
@ -309,12 +285,11 @@ jobs:
|
|||||||
arch: [x64]
|
arch: [x64]
|
||||||
cublas: [ON]
|
cublas: [ON]
|
||||||
sdl2: [ON]
|
sdl2: [ON]
|
||||||
cuda-toolkit: [12.2.0, 11.8.0]
|
|
||||||
include:
|
include:
|
||||||
- arch: x64
|
- arch: x64
|
||||||
s2arc: x64
|
s2arc: x64
|
||||||
- sdl2: ON
|
- sdl2: ON
|
||||||
s2ver: 2.28.5
|
s2ver: 2.26.0
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Clone
|
- name: Clone
|
||||||
@ -325,9 +300,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Install CUDA Toolkit
|
- name: Install CUDA Toolkit
|
||||||
id: cuda-toolkit
|
id: cuda-toolkit
|
||||||
uses: Jimver/cuda-toolkit@v0.2.11
|
uses: Jimver/cuda-toolkit@v0.2.10
|
||||||
with:
|
|
||||||
cuda: '${{ matrix.cuda-toolkit }}'
|
|
||||||
|
|
||||||
- name: Fetch SDL2 and set SDL2_DIR
|
- name: Fetch SDL2 and set SDL2_DIR
|
||||||
if: matrix.sdl2 == 'ON'
|
if: matrix.sdl2 == 'ON'
|
||||||
@ -340,13 +313,12 @@ jobs:
|
|||||||
run: >
|
run: >
|
||||||
cmake -S . -B ./build -A ${{ matrix.arch }}
|
cmake -S . -B ./build -A ${{ matrix.arch }}
|
||||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||||
-DWHISPER_CUBLAS=${{ matrix.cublas }}
|
-DWHISPER_CUBLAS=1
|
||||||
-DWHISPER_SDL2=${{ matrix.sdl2 }}
|
|
||||||
|
|
||||||
- name: Build ${{ matrix.cuda-toolkit }}
|
- name: Build
|
||||||
run: |
|
run: |
|
||||||
cd ./build
|
cd ./build
|
||||||
cmake --build . --config ${{ matrix.build }}
|
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
|
||||||
|
|
||||||
- name: Copy CUDA DLLs
|
- name: Copy CUDA DLLs
|
||||||
run: >
|
run: >
|
||||||
@ -363,7 +335,7 @@ jobs:
|
|||||||
if: matrix.sdl2 == 'ON'
|
if: matrix.sdl2 == 'ON'
|
||||||
uses: actions/upload-artifact@v1
|
uses: actions/upload-artifact@v1
|
||||||
with:
|
with:
|
||||||
name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}
|
name: whisper-cublas-bin-${{ matrix.arch }}
|
||||||
path: build/bin/${{ matrix.build }}
|
path: build/bin/${{ matrix.build }}
|
||||||
|
|
||||||
emscripten:
|
emscripten:
|
||||||
|
57
.github/workflows/docker.yml
vendored
@ -1,57 +0,0 @@
|
|||||||
# Builds the whisper.cpp Docker images and publishes them to the GitHub Container
# Registry (ghcr.io). Pull requests only build; pushes to master also publish.
name: Publish Docker image

on:
  pull_request:
  push:
    branches:
      - master

jobs:
  push_to_registry:
    name: Push Docker image to Docker Hub
    if: github.event.pull_request.draft == false

    runs-on: ubuntu-latest
    env:
      COMMIT_SHA: ${{ github.sha }}
    strategy:
      matrix:
        # One entry per image: tag suffix, Dockerfile path and target platform list.
        config:
          - { tag: "main", dockerfile: ".devops/main.Dockerfile", platform: "linux/amd64,linux/arm64" }
          - { tag: "main-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platform: "linux/amd64" }

    steps:
      - name: Check out the repo
        uses: actions/checkout@v3

      # QEMU + Buildx are needed to build the non-native platforms listed in the matrix.
      - name: Set up QEMU
        uses: docker/setup-qemu-action@v3

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      # Despite the step name, the login target is ghcr.io, not Docker Hub.
      - name: Log in to Docker Hub
        uses: docker/login-action@v3
        with:
          registry: ghcr.io
          username: ${{ github.repository_owner }}
          password: ${{ secrets.GITHUB_TOKEN }}

      # Immutable tag "<tag>-<commit sha>", published on pushes only.
      - name: Build and push Docker image (versioned)
        if: github.event_name == 'push'
        uses: docker/build-push-action@v5
        with:
          context: .
          push: true
          # Must match the matrix key ("platform"); a misspelled key evaluates to empty.
          platforms: ${{ matrix.config.platform }}
          tags: "ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}-${{ env.COMMIT_SHA }}"
          file: ${{ matrix.config.dockerfile }}

      # Moving tag "<tag>"; built on every event, pushed only on pushes.
      - name: Build and push Docker image (tagged)
        uses: docker/build-push-action@v4
        with:
          context: .
          push: ${{ github.event_name == 'push' }}
          # Must match the matrix key ("platform"); a misspelled key evaluates to empty.
          platforms: ${{ matrix.config.platform }}
          tags: "ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}"
          file: ${{ matrix.config.dockerfile }}
|
|
@ -1,6 +1,6 @@
|
|||||||
cmake_minimum_required (VERSION 3.5)
|
cmake_minimum_required (VERSION 3.5)
|
||||||
|
|
||||||
project(whisper.cpp VERSION 1.5.3)
|
project(whisper.cpp VERSION 1.5.0)
|
||||||
|
|
||||||
# Add path to modules
|
# Add path to modules
|
||||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
||||||
@ -218,17 +218,11 @@ if (WHISPER_CUBLAS)
|
|||||||
add_compile_definitions(GGML_USE_CUBLAS)
|
add_compile_definitions(GGML_USE_CUBLAS)
|
||||||
|
|
||||||
if (WHISPER_STATIC)
|
if (WHISPER_STATIC)
|
||||||
if (WIN32)
|
|
||||||
# As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
|
|
||||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
|
|
||||||
else ()
|
|
||||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
|
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
|
||||||
endif()
|
|
||||||
else()
|
else()
|
||||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
|
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cuda_driver)
|
|
||||||
else()
|
else()
|
||||||
message(FATAL_ERROR "cuBLAS not found")
|
message(FATAL_ERROR "cuBLAS not found")
|
||||||
endif()
|
endif()
|
||||||
@ -344,8 +338,8 @@ else()
|
|||||||
endif()
|
endif()
|
||||||
else()
|
else()
|
||||||
if (EMSCRIPTEN)
|
if (EMSCRIPTEN)
|
||||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread -s TOTAL_STACK=5242880")
|
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread")
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread -s TOTAL_STACK=5242880")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread")
|
||||||
else()
|
else()
|
||||||
if(NOT WHISPER_NO_AVX)
|
if(NOT WHISPER_NO_AVX)
|
||||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx")
|
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx")
|
||||||
@ -527,13 +521,7 @@ endif()
|
|||||||
|
|
||||||
if (GGML_SOURCES_CUDA)
|
if (GGML_SOURCES_CUDA)
|
||||||
message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
|
message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
|
||||||
# Only configure ggml CUDA architectures if not globally set
|
set_property(TARGET whisper PROPERTY CUDA_ARCHITECTURES OFF)
|
||||||
if (NOT DEFINED GGML_CUDA_ARCHITECTURES)
|
|
||||||
# Not overridden by user, so set defaults
|
|
||||||
set(GGML_CUDA_ARCHITECTURES 52 61 70)
|
|
||||||
endif()
|
|
||||||
message(STATUS "GGML Configuring CUDA architectures ${GGML_CUDA_ARCHITECTURES}")
|
|
||||||
set_property(TARGET whisper PROPERTY CUDA_ARCHITECTURES ${GGML_CUDA_ARCHITECTURES})
|
|
||||||
set_property(TARGET whisper PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
|
set_property(TARGET whisper PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
@ -545,7 +533,7 @@ target_compile_definitions(${TARGET} PUBLIC
|
|||||||
${WHISPER_EXTRA_FLAGS}
|
${WHISPER_EXTRA_FLAGS}
|
||||||
)
|
)
|
||||||
|
|
||||||
set_target_properties(${TARGET} PROPERTIES PUBLIC_HEADER "ggml.h;whisper.h")
|
set_target_properties(${TARGET} PROPERTIES PUBLIC_HEADER "whisper.h")
|
||||||
|
|
||||||
include(GNUInstallDirs)
|
include(GNUInstallDirs)
|
||||||
|
|
||||||
|
2
Makefile
@ -206,7 +206,7 @@ ifdef WHISPER_CUBLAS
|
|||||||
|
|
||||||
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
|
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
|
||||||
CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
|
CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
|
||||||
LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib
|
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib
|
||||||
WHISPER_OBJ += ggml-cuda.o
|
WHISPER_OBJ += ggml-cuda.o
|
||||||
NVCC = nvcc
|
NVCC = nvcc
|
||||||
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=$(CUDA_ARCH_FLAG)
|
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=$(CUDA_ARCH_FLAG)
|
||||||
|
@ -2,26 +2,41 @@
|
|||||||
|
|
||||||
import PackageDescription
|
import PackageDescription
|
||||||
|
|
||||||
let package = Package(
|
#if arch(arm) || arch(arm64)
|
||||||
name: "whisper",
|
let platforms: [SupportedPlatform]? = [
|
||||||
platforms: [
|
|
||||||
.macOS(.v12),
|
.macOS(.v12),
|
||||||
.iOS(.v14),
|
.iOS(.v14),
|
||||||
.watchOS(.v4),
|
.watchOS(.v4),
|
||||||
.tvOS(.v14)
|
.tvOS(.v14)
|
||||||
],
|
]
|
||||||
|
let exclude: [String] = []
|
||||||
|
let resources: [Resource] = [
|
||||||
|
.process("ggml-metal.metal")
|
||||||
|
]
|
||||||
|
let additionalSources: [String] = ["ggml-metal.m"]
|
||||||
|
let additionalSettings: [CSetting] = [
|
||||||
|
.unsafeFlags(["-fno-objc-arc"]),
|
||||||
|
.define("GGML_USE_METAL")
|
||||||
|
]
|
||||||
|
#else
|
||||||
|
let platforms: [SupportedPlatform]? = nil
|
||||||
|
let exclude: [String] = ["ggml-metal.metal"]
|
||||||
|
let resources: [Resource] = []
|
||||||
|
let additionalSources: [String] = []
|
||||||
|
let additionalSettings: [CSetting] = []
|
||||||
|
#endif
|
||||||
|
|
||||||
|
let package = Package(
|
||||||
|
name: "whisper",
|
||||||
|
platforms: platforms,
|
||||||
products: [
|
products: [
|
||||||
.library(name: "whisper", targets: ["whisper"]),
|
.library(name: "whisper", targets: ["whisper"]),
|
||||||
],
|
],
|
||||||
dependencies: [
|
|
||||||
.package(url: "https://github.com/ggerganov/ggml.git", .branch("master"))
|
|
||||||
],
|
|
||||||
targets: [
|
targets: [
|
||||||
.target(
|
.target(
|
||||||
name: "whisper",
|
name: "whisper",
|
||||||
dependencies: ["ggml"],
|
|
||||||
path: ".",
|
path: ".",
|
||||||
exclude: [
|
exclude: exclude + [
|
||||||
"bindings",
|
"bindings",
|
||||||
"cmake",
|
"cmake",
|
||||||
"coreml",
|
"coreml",
|
||||||
@ -36,20 +51,23 @@ let package = Package(
|
|||||||
"Makefile"
|
"Makefile"
|
||||||
],
|
],
|
||||||
sources: [
|
sources: [
|
||||||
|
"ggml.c",
|
||||||
"whisper.cpp",
|
"whisper.cpp",
|
||||||
],
|
"ggml-alloc.c",
|
||||||
|
"ggml-backend.c",
|
||||||
|
"ggml-quants.c"
|
||||||
|
] + additionalSources,
|
||||||
|
resources: resources,
|
||||||
publicHeadersPath: "spm-headers",
|
publicHeadersPath: "spm-headers",
|
||||||
cSettings: [
|
cSettings: [
|
||||||
.unsafeFlags(["-Wno-shorten-64-to-32", "-O3", "-DNDEBUG"]),
|
.unsafeFlags(["-Wno-shorten-64-to-32", "-O3", "-DNDEBUG"]),
|
||||||
.define("GGML_USE_ACCELERATE"),
|
.define("GGML_USE_ACCELERATE")
|
||||||
.unsafeFlags(["-fno-objc-arc"]),
|
|
||||||
.define("GGML_USE_METAL")
|
|
||||||
// NOTE: NEW_LAPACK will require iOS version 16.4+
|
// NOTE: NEW_LAPACK will require iOS version 16.4+
|
||||||
// We should consider adding this in the future when we drop support for iOS 14
|
// We should consider adding this in the future when we drop support for iOS 14
|
||||||
// (ref: https://developer.apple.com/documentation/accelerate/1513264-cblas_sgemm?language=objc)
|
// (ref: https://developer.apple.com/documentation/accelerate/1513264-cblas_sgemm?language=objc)
|
||||||
// .define("ACCELERATE_NEW_LAPACK"),
|
// .define("ACCELERATE_NEW_LAPACK"),
|
||||||
// .define("ACCELERATE_LAPACK_ILP64")
|
// .define("ACCELERATE_LAPACK_ILP64")
|
||||||
],
|
] + additionalSettings,
|
||||||
linkerSettings: [
|
linkerSettings: [
|
||||||
.linkedFramework("Accelerate")
|
.linkedFramework("Accelerate")
|
||||||
]
|
]
|
||||||
|
43
README.md
@ -6,7 +6,7 @@
|
|||||||
[](https://opensource.org/licenses/MIT)
|
[](https://opensource.org/licenses/MIT)
|
||||||
[](https://www.npmjs.com/package/whisper.cpp/)
|
[](https://www.npmjs.com/package/whisper.cpp/)
|
||||||
|
|
||||||
Stable: [v1.5.3](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.5.3) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
|
Stable: [v1.5.0](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.5.0) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
|
||||||
|
|
||||||
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
|
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
|
||||||
|
|
||||||
@ -33,7 +33,6 @@ Supported platforms:
|
|||||||
- [x] [WebAssembly](examples/whisper.wasm)
|
- [x] [WebAssembly](examples/whisper.wasm)
|
||||||
- [x] Windows ([MSVC](https://github.com/ggerganov/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggerganov/whisper.cpp/issues/168))
|
- [x] Windows ([MSVC](https://github.com/ggerganov/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggerganov/whisper.cpp/issues/168))
|
||||||
- [x] [Raspberry Pi](https://github.com/ggerganov/whisper.cpp/discussions/166)
|
- [x] [Raspberry Pi](https://github.com/ggerganov/whisper.cpp/discussions/166)
|
||||||
- [x] [docker](https://github.com/ggerganov/whisper.cpp/pkgs/container/whisper.cpp)
|
|
||||||
|
|
||||||
The entire high-level implementation of the model is contained in [whisper.h](whisper.h) and [whisper.cpp](whisper.cpp).
|
The entire high-level implementation of the model is contained in [whisper.h](whisper.h) and [whisper.cpp](whisper.cpp).
|
||||||
The rest of the code is part of the [ggml](https://github.com/ggerganov/ggml) machine learning library.
|
The rest of the code is part of the [ggml](https://github.com/ggerganov/ggml) machine learning library.
|
||||||
@ -111,8 +110,8 @@ options:
|
|||||||
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
|
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
|
||||||
-ml N, --max-len N [0 ] maximum segment length in characters
|
-ml N, --max-len N [0 ] maximum segment length in characters
|
||||||
-sow, --split-on-word [false ] split on word rather than on token
|
-sow, --split-on-word [false ] split on word rather than on token
|
||||||
-bo N, --best-of N [5 ] number of best candidates to keep
|
-bo N, --best-of N [2 ] number of best candidates to keep
|
||||||
-bs N, --beam-size N [5 ] beam size for beam search
|
-bs N, --beam-size N [-1 ] beam size for beam search
|
||||||
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
|
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
|
||||||
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
|
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
|
||||||
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
|
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
|
||||||
@ -129,7 +128,6 @@ options:
|
|||||||
-fp, --font-path [/System/Library/Fonts/Supplemental/Courier New Bold.ttf] path to a monospace font for karaoke video
|
-fp, --font-path [/System/Library/Fonts/Supplemental/Courier New Bold.ttf] path to a monospace font for karaoke video
|
||||||
-ocsv, --output-csv [false ] output result in a CSV file
|
-ocsv, --output-csv [false ] output result in a CSV file
|
||||||
-oj, --output-json [false ] output result in a JSON file
|
-oj, --output-json [false ] output result in a JSON file
|
||||||
-ojf, --output-json-full [false ] include more information in the JSON file
|
|
||||||
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
|
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
|
||||||
-ps, --print-special [false ] print special tokens
|
-ps, --print-special [false ] print special tokens
|
||||||
-pc, --print-colors [false ] print colors
|
-pc, --print-colors [false ] print colors
|
||||||
@ -141,8 +139,7 @@ options:
|
|||||||
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
|
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
|
||||||
-f FNAME, --file FNAME [ ] input WAV file path
|
-f FNAME, --file FNAME [ ] input WAV file path
|
||||||
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
|
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
|
||||||
-ls, --log-score [false ] log best decoder scores of tokens
|
-ls, --log-score [false ] log best decoder scores of token
|
||||||
-ng, --no-gpu [false ] disable GPU
|
|
||||||
|
|
||||||
|
|
||||||
bash ./models/download-ggml-model.sh base.en
|
bash ./models/download-ggml-model.sh base.en
|
||||||
@ -449,36 +446,6 @@ make clean
|
|||||||
WHISPER_OPENBLAS=1 make -j
|
WHISPER_OPENBLAS=1 make -j
|
||||||
```
|
```
|
||||||
|
|
||||||
## Docker
|
|
||||||
|
|
||||||
### Prerequisites
|
|
||||||
* Docker must be installed and running on your system.
|
|
||||||
* Create a folder to store big models & intermediate files (ex. /whisper/models)
|
|
||||||
|
|
||||||
### Images
|
|
||||||
We have two Docker images available for this project:
|
|
||||||
|
|
||||||
1. `ghcr.io/ggerganov/whisper.cpp:main`: This image includes the main executable file as well as `curl` and `ffmpeg`. (platforms: `linux/amd64`, `linux/arm64`)
|
|
||||||
2. `ghcr.io/ggerganov/whisper.cpp:main-cuda`: Same as `main` but compiled with CUDA support. (platforms: `linux/amd64`)
|
|
||||||
|
|
||||||
### Usage
|
|
||||||
|
|
||||||
```shell
|
|
||||||
# download model and persist it in a local folder
|
|
||||||
docker run -it --rm \
|
|
||||||
-v path/to/models:/models \
|
|
||||||
whisper.cpp:main "./models/download-ggml-model.sh base /models"
|
|
||||||
# transcribe an audio file
|
|
||||||
docker run -it --rm \
|
|
||||||
-v path/to/models:/models \
|
|
||||||
-v path/to/audios:/audios \
|
|
||||||
whisper.cpp:main "./main -m /models/ggml-base.bin -f /audios/jfk.wav"
|
|
||||||
# transcribe an audio file in samples folder
|
|
||||||
docker run -it --rm \
|
|
||||||
-v path/to/models:/models \
|
|
||||||
whisper.cpp:main "./main -m /models/ggml-base.bin -f ./samples/jfk.wav"
|
|
||||||
```
|
|
||||||
|
|
||||||
## Limitations
|
## Limitations
|
||||||
|
|
||||||
- Inference only
|
- Inference only
|
||||||
@ -801,7 +768,6 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
|
|||||||
| [bench](examples/bench) | [bench.wasm](examples/bench.wasm) | Benchmark the performance of Whisper on your machine |
|
| [bench](examples/bench) | [bench.wasm](examples/bench.wasm) | Benchmark the performance of Whisper on your machine |
|
||||||
| [stream](examples/stream) | [stream.wasm](examples/stream.wasm) | Real-time transcription of raw microphone capture |
|
| [stream](examples/stream) | [stream.wasm](examples/stream.wasm) | Real-time transcription of raw microphone capture |
|
||||||
| [command](examples/command) | [command.wasm](examples/command.wasm) | Basic voice assistant example for receiving voice commands from the mic |
|
| [command](examples/command) | [command.wasm](examples/command.wasm) | Basic voice assistant example for receiving voice commands from the mic |
|
||||||
| [wchess](examples/wchess) | [wchess.wasm](examples/wchess) | Voice-controlled chess |
|
|
||||||
| [talk](examples/talk) | [talk.wasm](examples/talk.wasm) | Talk with a GPT-2 bot |
|
| [talk](examples/talk) | [talk.wasm](examples/talk.wasm) | Talk with a GPT-2 bot |
|
||||||
| [talk-llama](examples/talk-llama) | | Talk with a LLaMA bot |
|
| [talk-llama](examples/talk-llama) | | Talk with a LLaMA bot |
|
||||||
| [whisper.objc](examples/whisper.objc) | | iOS mobile application using whisper.cpp |
|
| [whisper.objc](examples/whisper.objc) | | iOS mobile application using whisper.cpp |
|
||||||
@ -811,7 +777,6 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
|
|||||||
| [generate-karaoke.sh](examples/generate-karaoke.sh) | | Helper script to easily [generate a karaoke video](https://youtu.be/uj7hVta4blM) of raw audio capture |
|
| [generate-karaoke.sh](examples/generate-karaoke.sh) | | Helper script to easily [generate a karaoke video](https://youtu.be/uj7hVta4blM) of raw audio capture |
|
||||||
| [livestream.sh](examples/livestream.sh) | | [Livestream audio transcription](https://github.com/ggerganov/whisper.cpp/issues/185) |
|
| [livestream.sh](examples/livestream.sh) | | [Livestream audio transcription](https://github.com/ggerganov/whisper.cpp/issues/185) |
|
||||||
| [yt-wsp.sh](examples/yt-wsp.sh) | | Download + transcribe and/or translate any VOD [(original)](https://gist.github.com/DaniruKun/96f763ec1a037cc92fe1a059b643b818) |
|
| [yt-wsp.sh](examples/yt-wsp.sh) | | Download + transcribe and/or translate any VOD [(original)](https://gist.github.com/DaniruKun/96f763ec1a037cc92fe1a059b643b818) |
|
||||||
| [server](examples/server) | | HTTP transcription server with OAI-like API |
|
|
||||||
|
|
||||||
## [Discussions](https://github.com/ggerganov/whisper.cpp/discussions)
|
## [Discussions](https://github.com/ggerganov/whisper.cpp/discussions)
|
||||||
|
|
||||||
|
@ -1,26 +1,9 @@
|
|||||||
ifndef UNAME_S
|
|
||||||
UNAME_S := $(shell uname -s)
|
|
||||||
endif
|
|
||||||
|
|
||||||
ifndef UNAME_P
|
|
||||||
UNAME_P := $(shell uname -p)
|
|
||||||
endif
|
|
||||||
|
|
||||||
ifndef UNAME_M
|
|
||||||
UNAME_M := $(shell uname -m)
|
|
||||||
endif
|
|
||||||
|
|
||||||
GGML_METAL_PATH_RESOURCES := $(abspath ../..)
|
|
||||||
BUILD_DIR := build
|
BUILD_DIR := build
|
||||||
MODELS_DIR := models
|
MODELS_DIR := models
|
||||||
EXAMPLES_DIR := $(wildcard examples/*)
|
EXAMPLES_DIR := $(wildcard examples/*)
|
||||||
INCLUDE_PATH := $(abspath ../..)
|
INCLUDE_PATH := $(abspath ../..)
|
||||||
LIBRARY_PATH := $(abspath ../..)
|
LIBRARY_PATH := $(abspath ../..)
|
||||||
|
|
||||||
ifeq ($(UNAME_S),Darwin)
|
|
||||||
EXT_LDFLAGS := -framework Foundation -framework Metal -framework MetalKit
|
|
||||||
endif
|
|
||||||
|
|
||||||
all: clean whisper examples
|
all: clean whisper examples
|
||||||
|
|
||||||
whisper: mkdir
|
whisper: mkdir
|
||||||
@ -28,13 +11,8 @@ whisper: mkdir
|
|||||||
@${MAKE} -C ../.. libwhisper.a
|
@${MAKE} -C ../.. libwhisper.a
|
||||||
|
|
||||||
test: model-small whisper modtidy
|
test: model-small whisper modtidy
|
||||||
ifeq ($(UNAME_S),Darwin)
|
|
||||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go test -ldflags "-extldflags '$(EXT_LDFLAGS)'" -v .
|
|
||||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go test -ldflags "-extldflags '$(EXT_LDFLAGS)'" -v ./pkg/whisper/...
|
|
||||||
else
|
|
||||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v .
|
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v .
|
||||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v ./pkg/whisper/...
|
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v ./pkg/whisper/...
|
||||||
endif
|
|
||||||
|
|
||||||
examples: $(EXAMPLES_DIR)
|
examples: $(EXAMPLES_DIR)
|
||||||
|
|
||||||
@ -43,11 +21,7 @@ model-small: mkdir examples/go-model-download
|
|||||||
|
|
||||||
$(EXAMPLES_DIR): mkdir whisper modtidy
|
$(EXAMPLES_DIR): mkdir whisper modtidy
|
||||||
@echo Build example $(notdir $@)
|
@echo Build example $(notdir $@)
|
||||||
ifeq ($(UNAME_S),Darwin)
|
|
||||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go build ${BUILD_FLAGS} -ldflags "-extldflags '$(EXT_LDFLAGS)'" -o ${BUILD_DIR}/$(notdir $@) ./$@
|
|
||||||
else
|
|
||||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
|
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
|
||||||
endif
|
|
||||||
|
|
||||||
mkdir:
|
mkdir:
|
||||||
@echo Mkdir ${BUILD_DIR}
|
@echo Mkdir ${BUILD_DIR}
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "whisper.cpp",
|
"name": "whisper.cpp",
|
||||||
"version": "1.5.3",
|
"version": "1.5.0",
|
||||||
"description": "Whisper speech recognition",
|
"description": "Whisper speech recognition",
|
||||||
"main": "whisper.js",
|
"main": "whisper.js",
|
||||||
"scripts": {
|
"scripts": {
|
||||||
|
@ -70,7 +70,7 @@ extern "C" {
|
|||||||
void (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
void (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
||||||
|
|
||||||
// compute graph without a plan
|
// compute graph without a plan
|
||||||
bool (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
void (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
||||||
|
|
||||||
// check if the backend supports an operation
|
// check if the backend supports an operation
|
||||||
bool (*supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
|
bool (*supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
|
||||||
|
@ -156,8 +156,8 @@ void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_
|
|||||||
backend->iface.graph_plan_compute(backend, plan);
|
backend->iface.graph_plan_compute(backend, plan);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
void ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
||||||
return backend->iface.graph_compute(backend, cgraph);
|
backend->iface.graph_compute(backend, cgraph);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
|
bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
|
||||||
|
@ -52,7 +52,7 @@ extern "C" {
|
|||||||
|
|
||||||
GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
||||||
GGML_API void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
GGML_API void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
||||||
GGML_API bool ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
GGML_API void ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
||||||
GGML_API bool ggml_backend_supports_op (ggml_backend_t backend, const struct ggml_tensor * op);
|
GGML_API bool ggml_backend_supports_op (ggml_backend_t backend, const struct ggml_tensor * op);
|
||||||
|
|
||||||
// tensor copy between different backends
|
// tensor copy between different backends
|
||||||
|
@ -14,10 +14,6 @@ if (WHISPER_SDL2)
|
|||||||
message(STATUS "SDL2_LIBRARIES = ${SDL2_LIBRARIES}")
|
message(STATUS "SDL2_LIBRARIES = ${SDL2_LIBRARIES}")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (WHISPER_CLBLAST)
|
|
||||||
find_package(CLBlast REQUIRED)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# common
|
# common
|
||||||
|
|
||||||
set(TARGET common)
|
set(TARGET common)
|
||||||
@ -77,5 +73,3 @@ else()
|
|||||||
add_subdirectory(talk-llama)
|
add_subdirectory(talk-llama)
|
||||||
add_subdirectory(lsp)
|
add_subdirectory(lsp)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
add_subdirectory(wchess)
|
|
||||||
|
@ -22,7 +22,6 @@ var printTextarea = (function() {
|
|||||||
async function clearCache() {
|
async function clearCache() {
|
||||||
if (confirm('Are you sure you want to clear the cache?\nAll the models will be downloaded again.')) {
|
if (confirm('Are you sure you want to clear the cache?\nAll the models will be downloaded again.')) {
|
||||||
indexedDB.deleteDatabase(dbName);
|
indexedDB.deleteDatabase(dbName);
|
||||||
location.reload();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -17,37 +17,28 @@ options:
|
|||||||
-d N, --duration N [0 ] duration of audio to process in milliseconds
|
-d N, --duration N [0 ] duration of audio to process in milliseconds
|
||||||
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
|
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
|
||||||
-ml N, --max-len N [0 ] maximum segment length in characters
|
-ml N, --max-len N [0 ] maximum segment length in characters
|
||||||
-sow, --split-on-word [false ] split on word rather than on token
|
|
||||||
-bo N, --best-of N [5 ] number of best candidates to keep
|
-bo N, --best-of N [5 ] number of best candidates to keep
|
||||||
-bs N, --beam-size N [5 ] beam size for beam search
|
-bs N, --beam-size N [-1 ] beam size for beam search
|
||||||
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
|
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
|
||||||
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
|
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
|
||||||
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
|
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
|
||||||
-debug, --debug-mode [false ] enable debug mode (eg. dump log_mel)
|
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
|
||||||
-tr, --translate [false ] translate from source language to english
|
-tr, --translate [false ] translate from source language to english
|
||||||
-di, --diarize [false ] stereo audio diarization
|
-di, --diarize [false ] stereo audio diarization
|
||||||
-tdrz, --tinydiarize [false ] enable tinydiarize (requires a tdrz model)
|
|
||||||
-nf, --no-fallback [false ] do not use temperature fallback while decoding
|
-nf, --no-fallback [false ] do not use temperature fallback while decoding
|
||||||
-otxt, --output-txt [false ] output result in a text file
|
-otxt, --output-txt [false ] output result in a text file
|
||||||
-ovtt, --output-vtt [false ] output result in a vtt file
|
-ovtt, --output-vtt [false ] output result in a vtt file
|
||||||
-osrt, --output-srt [false ] output result in a srt file
|
-osrt, --output-srt [false ] output result in a srt file
|
||||||
-olrc, --output-lrc [false ] output result in a lrc file
|
|
||||||
-owts, --output-words [false ] output script for generating karaoke video
|
-owts, --output-words [false ] output script for generating karaoke video
|
||||||
-fp, --font-path [/System/Library/Fonts/Supplemental/Courier New Bold.ttf] path to a monospace font for karaoke video
|
|
||||||
-ocsv, --output-csv [false ] output result in a CSV file
|
-ocsv, --output-csv [false ] output result in a CSV file
|
||||||
-oj, --output-json [false ] output result in a JSON file
|
-oj, --output-json [false ] output result in a JSON file
|
||||||
-ojf, --output-json-full [false ] include more information in the JSON file
|
|
||||||
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
|
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
|
||||||
-ps, --print-special [false ] print special tokens
|
-ps, --print-special [false ] print special tokens
|
||||||
-pc, --print-colors [false ] print colors
|
-pc, --print-colors [false ] print colors
|
||||||
-pp, --print-progress [false ] print progress
|
-pp, --print-progress [false ] print progress
|
||||||
-nt, --no-timestamps [false ] do not print timestamps
|
-nt, --no-timestamps [true ] do not print timestamps
|
||||||
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
|
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
|
||||||
-dl, --detect-language [false ] exit after automatically detecting language
|
|
||||||
--prompt PROMPT [ ] initial prompt
|
--prompt PROMPT [ ] initial prompt
|
||||||
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
|
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
|
||||||
-f FNAME, --file FNAME [ ] input WAV file path
|
-f FNAME, --file FNAME [ ] input WAV file path
|
||||||
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
|
|
||||||
-ls, --log-score [false ] log best decoder scores of tokens
|
|
||||||
-ng, --no-gpu [false ] disable GPU
|
|
||||||
```
|
```
|
||||||
|
@ -4,9 +4,3 @@ add_executable(${TARGET} server.cpp httplib.h json.hpp)
|
|||||||
include(DefaultTargetOptions)
|
include(DefaultTargetOptions)
|
||||||
|
|
||||||
target_link_libraries(${TARGET} PRIVATE common whisper ${CMAKE_THREAD_LIBS_INIT})
|
target_link_libraries(${TARGET} PRIVATE common whisper ${CMAKE_THREAD_LIBS_INIT})
|
||||||
|
|
||||||
# Check if the compiler is MinGW
|
|
||||||
if(MINGW)
|
|
||||||
# Link the necessary libraries for SSL and Winsock
|
|
||||||
target_link_libraries(${TARGET} PRIVATE -lcrypt32 -lssl -lcrypto -lws2_32)
|
|
||||||
endif()
|
|
||||||
|
@ -2,10 +2,6 @@
|
|||||||
|
|
||||||
Simple http server. WAV Files are passed to the inference model via http requests.
|
Simple http server. WAV Files are passed to the inference model via http requests.
|
||||||
|
|
||||||
https://github.com/ggerganov/whisper.cpp/assets/1991296/e983ee53-8741-4eb5-9048-afe5e4594b8f
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
```
|
```
|
||||||
./server -h
|
./server -h
|
||||||
|
|
||||||
@ -33,7 +29,6 @@ options:
|
|||||||
-nf, --no-fallback [false ] do not use temperature fallback while decoding
|
-nf, --no-fallback [false ] do not use temperature fallback while decoding
|
||||||
-ps, --print-special [false ] print special tokens
|
-ps, --print-special [false ] print special tokens
|
||||||
-pc, --print-colors [false ] print colors
|
-pc, --print-colors [false ] print colors
|
||||||
-pr, --print-realtime [false ] print output in realtime
|
|
||||||
-pp, --print-progress [false ] print progress
|
-pp, --print-progress [false ] print progress
|
||||||
-nt, --no-timestamps [false ] do not print timestamps
|
-nt, --no-timestamps [false ] do not print timestamps
|
||||||
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
|
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
|
||||||
@ -43,12 +38,8 @@ options:
|
|||||||
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
|
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
|
||||||
--host HOST, [127.0.0.1] Hostname/ip-adress for the server
|
--host HOST, [127.0.0.1] Hostname/ip-adress for the server
|
||||||
--port PORT, [8080 ] Port number for the server
|
--port PORT, [8080 ] Port number for the server
|
||||||
--convert, [false ] Convert audio to WAV, requires ffmpeg on the server
|
|
||||||
```
|
```
|
||||||
|
|
||||||
> [!WARNING]
|
|
||||||
> **Do not run the server example with administrative privileges and ensure it's operated in a sandbox environment, especially since it involves risky operations like accepting user file uploads and using ffmpeg for format conversions. Always validate and sanitize inputs to guard against potential security threats.**
|
|
||||||
|
|
||||||
## request examples
|
## request examples
|
||||||
|
|
||||||
**/inference**
|
**/inference**
|
||||||
|
@ -11,7 +11,6 @@
|
|||||||
#include <thread>
|
#include <thread>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <sstream>
|
|
||||||
|
|
||||||
#if defined(_MSC_VER)
|
#if defined(_MSC_VER)
|
||||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||||
@ -44,8 +43,6 @@ struct server_params
|
|||||||
int32_t port = 8080;
|
int32_t port = 8080;
|
||||||
int32_t read_timeout = 600;
|
int32_t read_timeout = 600;
|
||||||
int32_t write_timeout = 600;
|
int32_t write_timeout = 600;
|
||||||
|
|
||||||
bool ffmpeg_converter = false;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct whisper_params {
|
struct whisper_params {
|
||||||
@ -75,7 +72,6 @@ struct whisper_params {
|
|||||||
bool no_fallback = false;
|
bool no_fallback = false;
|
||||||
bool print_special = false;
|
bool print_special = false;
|
||||||
bool print_colors = false;
|
bool print_colors = false;
|
||||||
bool print_realtime = false;
|
|
||||||
bool print_progress = false;
|
bool print_progress = false;
|
||||||
bool no_timestamps = false;
|
bool no_timestamps = false;
|
||||||
bool use_gpu = true;
|
bool use_gpu = true;
|
||||||
@ -148,7 +144,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|||||||
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
|
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
|
||||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||||
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
|
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
|
||||||
fprintf(stderr, " -pr, --print-realtime [%-7s] print output in realtime\n", params.print_realtime ? "true" : "false");
|
|
||||||
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
|
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
|
||||||
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
|
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
|
||||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
|
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
|
||||||
@ -160,7 +155,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|||||||
fprintf(stderr, " --host HOST, [%-7s] Hostname/ip-adress for the server\n", sparams.hostname.c_str());
|
fprintf(stderr, " --host HOST, [%-7s] Hostname/ip-adress for the server\n", sparams.hostname.c_str());
|
||||||
fprintf(stderr, " --port PORT, [%-7d] Port number for the server\n", sparams.port);
|
fprintf(stderr, " --port PORT, [%-7d] Port number for the server\n", sparams.port);
|
||||||
fprintf(stderr, " --public PATH, [%-7s] Path to the public folder\n", sparams.public_path.c_str());
|
fprintf(stderr, " --public PATH, [%-7s] Path to the public folder\n", sparams.public_path.c_str());
|
||||||
fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false");
|
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -194,7 +188,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
|
|||||||
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
|
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
|
||||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||||
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
||||||
else if (arg == "-pr" || arg == "--print-realtime") { params.print_realtime = true; }
|
|
||||||
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
|
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
|
||||||
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
|
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
|
||||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||||
@ -207,7 +200,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
|
|||||||
else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
|
else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
|
||||||
else if ( arg == "--host") { sparams.hostname = argv[++i]; }
|
else if ( arg == "--host") { sparams.hostname = argv[++i]; }
|
||||||
else if ( arg == "--public") { sparams.public_path = argv[++i]; }
|
else if ( arg == "--public") { sparams.public_path = argv[++i]; }
|
||||||
else if ( arg == "--convert") { sparams.ffmpeg_converter = true; }
|
|
||||||
else {
|
else {
|
||||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||||
whisper_print_usage(argc, argv, params, sparams);
|
whisper_print_usage(argc, argv, params, sparams);
|
||||||
@ -225,45 +217,6 @@ struct whisper_print_user_data {
|
|||||||
int progress_prev;
|
int progress_prev;
|
||||||
};
|
};
|
||||||
|
|
||||||
void check_ffmpeg_availibility() {
|
|
||||||
int result = system("ffmpeg -version");
|
|
||||||
|
|
||||||
if (result == 0) {
|
|
||||||
std::cout << "ffmpeg is available." << std::endl;
|
|
||||||
} else {
|
|
||||||
// ffmpeg is not available
|
|
||||||
std::cout << "ffmpeg is not found. Please ensure that ffmpeg is installed ";
|
|
||||||
std::cout << "and that its executable is included in your system's PATH. ";
|
|
||||||
exit(0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool convert_to_wav(const std::string & temp_filename, std::string & error_resp) {
|
|
||||||
std::ostringstream cmd_stream;
|
|
||||||
std::string converted_filename_temp = temp_filename + "_temp.wav";
|
|
||||||
cmd_stream << "ffmpeg -i \"" << temp_filename << "\" -ar 16000 -ac 1 -c:a pcm_s16le \"" << converted_filename_temp << "\" 2>&1";
|
|
||||||
std::string cmd = cmd_stream.str();
|
|
||||||
|
|
||||||
int status = std::system(cmd.c_str());
|
|
||||||
if (status != 0) {
|
|
||||||
error_resp = "{\"error\":\"FFmpeg conversion failed.\"}";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove the original file
|
|
||||||
if (remove(temp_filename.c_str()) != 0) {
|
|
||||||
error_resp = "{\"error\":\"Failed to remove the original file.\"}";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rename the temporary file to match the original filename
|
|
||||||
if (rename(converted_filename_temp.c_str(), temp_filename.c_str()) != 0) {
|
|
||||||
error_resp = "{\"error\":\"Failed to rename the temporary file.\"}";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) {
|
std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) {
|
||||||
std::string speaker = "";
|
std::string speaker = "";
|
||||||
const int64_t n_samples = pcmf32s[0].size();
|
const int64_t n_samples = pcmf32s[0].size();
|
||||||
@ -420,7 +373,7 @@ void get_req_parameters(const Request & req, whisper_params & params)
|
|||||||
{
|
{
|
||||||
params.response_format = req.get_file_value("response-format").content;
|
params.response_format = req.get_file_value("response-format").content;
|
||||||
}
|
}
|
||||||
if (req.has_file("temperature"))
|
if (req.has_file("temerature"))
|
||||||
{
|
{
|
||||||
params.userdef_temp = std::stof(req.get_file_value("temperature").content);
|
params.userdef_temp = std::stof(req.get_file_value("temperature").content);
|
||||||
}
|
}
|
||||||
@ -451,9 +404,6 @@ int main(int argc, char ** argv) {
|
|||||||
exit(0);
|
exit(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (sparams.ffmpeg_converter) {
|
|
||||||
check_ffmpeg_availibility();
|
|
||||||
}
|
|
||||||
// whisper init
|
// whisper init
|
||||||
struct whisper_context_params cparams;
|
struct whisper_context_params cparams;
|
||||||
cparams.use_gpu = params.use_gpu;
|
cparams.use_gpu = params.use_gpu;
|
||||||
@ -469,9 +419,6 @@ int main(int argc, char ** argv) {
|
|||||||
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
|
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
|
||||||
|
|
||||||
Server svr;
|
Server svr;
|
||||||
svr.set_default_headers({{"Server", "whisper.cpp"},
|
|
||||||
{"Access-Control-Allow-Origin", "*"},
|
|
||||||
{"Access-Control-Allow-Headers", "content-type"}});
|
|
||||||
|
|
||||||
std::string const default_content = "<html>hello</html>";
|
std::string const default_content = "<html>hello</html>";
|
||||||
|
|
||||||
@ -482,7 +429,7 @@ int main(int argc, char ** argv) {
|
|||||||
});
|
});
|
||||||
|
|
||||||
svr.Post("/inference", [&](const Request &req, Response &res){
|
svr.Post("/inference", [&](const Request &req, Response &res){
|
||||||
// acquire whisper model mutex lock
|
// aquire whisper model mutex lock
|
||||||
whisper_mutex.lock();
|
whisper_mutex.lock();
|
||||||
|
|
||||||
// first check user requested fields of the request
|
// first check user requested fields of the request
|
||||||
@ -506,35 +453,20 @@ int main(int argc, char ** argv) {
|
|||||||
std::vector<float> pcmf32; // mono-channel F32 PCM
|
std::vector<float> pcmf32; // mono-channel F32 PCM
|
||||||
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
|
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
|
||||||
|
|
||||||
// write to temporary file
|
// write file to temporary file
|
||||||
const std::string temp_filename = "whisper_server_temp_file.wav";
|
std::ofstream temp_file{filename, std::ios::binary};
|
||||||
std::ofstream temp_file{temp_filename, std::ios::binary};
|
|
||||||
temp_file << audio_file.content;
|
temp_file << audio_file.content;
|
||||||
temp_file.close();
|
|
||||||
|
|
||||||
// if file is not wav, convert to wav
|
|
||||||
|
|
||||||
if (sparams.ffmpeg_converter) {
|
|
||||||
std::string error_resp = "{\"error\":\"Failed to execute ffmpeg command.\"}";
|
|
||||||
const bool is_converted = convert_to_wav(temp_filename, error_resp);
|
|
||||||
if (!is_converted) {
|
|
||||||
res.set_content(error_resp, "application/json");
|
|
||||||
whisper_mutex.unlock();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// read wav content into pcmf32
|
// read wav content into pcmf32
|
||||||
if (!::read_wav(temp_filename, pcmf32, pcmf32s, params.diarize)) {
|
if (!::read_wav(filename, pcmf32, pcmf32s, params.diarize)) {
|
||||||
fprintf(stderr, "error: failed to read WAV file '%s'\n", temp_filename.c_str());
|
fprintf(stderr, "error: failed to read WAV file '%s'\n", filename.c_str());
|
||||||
const std::string error_resp = "{\"error\":\"failed to read WAV file\"}";
|
const std::string error_resp = "{\"error\":\"failed to read WAV file\"}";
|
||||||
res.set_content(error_resp, "application/json");
|
res.set_content(error_resp, "application/json");
|
||||||
std::remove(temp_filename.c_str());
|
|
||||||
whisper_mutex.unlock();
|
whisper_mutex.unlock();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// remove temp file
|
// remove temp file
|
||||||
std::remove(temp_filename.c_str());
|
std::remove(filename.c_str());
|
||||||
|
|
||||||
printf("Successfully loaded %s\n", filename.c_str());
|
printf("Successfully loaded %s\n", filename.c_str());
|
||||||
|
|
||||||
@ -571,6 +503,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
// run the inference
|
// run the inference
|
||||||
{
|
{
|
||||||
|
|
||||||
printf("Running whisper.cpp inference on %s\n", filename.c_str());
|
printf("Running whisper.cpp inference on %s\n", filename.c_str());
|
||||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||||
|
|
||||||
@ -589,7 +522,6 @@ int main(int argc, char ** argv) {
|
|||||||
wparams.duration_ms = params.duration_ms;
|
wparams.duration_ms = params.duration_ms;
|
||||||
|
|
||||||
wparams.thold_pt = params.word_thold;
|
wparams.thold_pt = params.word_thold;
|
||||||
wparams.max_len = params.max_len == 0 ? 60 : params.max_len;
|
|
||||||
wparams.split_on_word = params.split_on_word;
|
wparams.split_on_word = params.split_on_word;
|
||||||
|
|
||||||
wparams.speed_up = params.speed_up;
|
wparams.speed_up = params.speed_up;
|
||||||
@ -609,7 +541,7 @@ int main(int argc, char ** argv) {
|
|||||||
whisper_print_user_data user_data = { ¶ms, &pcmf32s, 0 };
|
whisper_print_user_data user_data = { ¶ms, &pcmf32s, 0 };
|
||||||
|
|
||||||
// this callback is called on each new segment
|
// this callback is called on each new segment
|
||||||
if (params.print_realtime) {
|
if (!wparams.print_realtime) {
|
||||||
wparams.new_segment_callback = whisper_print_segment_callback;
|
wparams.new_segment_callback = whisper_print_segment_callback;
|
||||||
wparams.new_segment_callback_user_data = &user_data;
|
wparams.new_segment_callback_user_data = &user_data;
|
||||||
}
|
}
|
||||||
@ -659,50 +591,6 @@ int main(int argc, char ** argv) {
|
|||||||
std::string results = output_str(ctx, params, pcmf32s);
|
std::string results = output_str(ctx, params, pcmf32s);
|
||||||
res.set_content(results.c_str(), "text/html");
|
res.set_content(results.c_str(), "text/html");
|
||||||
}
|
}
|
||||||
else if (params.response_format == srt_format)
|
|
||||||
{
|
|
||||||
std::stringstream ss;
|
|
||||||
const int n_segments = whisper_full_n_segments(ctx);
|
|
||||||
for (int i = 0; i < n_segments; ++i) {
|
|
||||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
|
||||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
|
||||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
|
||||||
std::string speaker = "";
|
|
||||||
|
|
||||||
if (params.diarize && pcmf32s.size() == 2)
|
|
||||||
{
|
|
||||||
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
|
|
||||||
}
|
|
||||||
|
|
||||||
ss << i + 1 + params.offset_n << "\n";
|
|
||||||
ss << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n";
|
|
||||||
ss << speaker << text << "\n\n";
|
|
||||||
}
|
|
||||||
res.set_content(ss.str(), "application/x-subrip");
|
|
||||||
} else if (params.response_format == vtt_format) {
|
|
||||||
std::stringstream ss;
|
|
||||||
|
|
||||||
ss << "WEBVTT\n\n";
|
|
||||||
|
|
||||||
const int n_segments = whisper_full_n_segments(ctx);
|
|
||||||
for (int i = 0; i < n_segments; ++i) {
|
|
||||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
|
||||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
|
||||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
|
||||||
std::string speaker = "";
|
|
||||||
|
|
||||||
if (params.diarize && pcmf32s.size() == 2)
|
|
||||||
{
|
|
||||||
speaker = estimate_diarization_speaker(pcmf32s, t0, t1, true);
|
|
||||||
speaker.insert(0, "<v Speaker");
|
|
||||||
speaker.append(">");
|
|
||||||
}
|
|
||||||
|
|
||||||
ss << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
|
|
||||||
ss << speaker << text << "\n\n";
|
|
||||||
}
|
|
||||||
res.set_content(ss.str(), "text/vtt");
|
|
||||||
}
|
|
||||||
// TODO add more output formats
|
// TODO add more output formats
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
@ -1,18 +1,25 @@
|
|||||||
if (WHISPER_SDL2)
|
if (WHISPER_SDL2)
|
||||||
# talk-llama
|
# talk-llama
|
||||||
set(TARGET talk-llama)
|
set(TARGET talk-llama)
|
||||||
add_executable(${TARGET} talk-llama.cpp llama.cpp)
|
#add_executable(${TARGET} talk-llama.cpp llama.cpp)
|
||||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
#target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
|
||||||
|
#target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||||
|
|
||||||
if (WHISPER_CLBLAST)
|
# TODO: this is temporary
|
||||||
set(CLBLAST_LIBNAME clblast)
|
# need to export ggml symbols for MSVC, but too lazy ..
|
||||||
endif ()
|
add_executable(${TARGET}
|
||||||
target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${SDL2_LIBRARIES} ${CLBLAST_LIBNAME} ${CMAKE_THREAD_LIBS_INIT})
|
talk-llama.cpp
|
||||||
|
llama.cpp
|
||||||
|
../common.cpp
|
||||||
|
../common-sdl.cpp
|
||||||
|
../../ggml.c
|
||||||
|
../../ggml-alloc.c
|
||||||
|
../../ggml-backend.c
|
||||||
|
../../ggml-quants.c
|
||||||
|
../../whisper.cpp)
|
||||||
|
|
||||||
if(WIN32)
|
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
|
||||||
# It requires Windows 8.1 or later for PrefetchVirtualMemory
|
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||||
target_compile_definitions(${TARGET} PRIVATE -D_WIN32_WINNT=0x0602)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
include(DefaultTargetOptions)
|
include(DefaultTargetOptions)
|
||||||
endif ()
|
endif ()
|
||||||
|
@ -39,11 +39,10 @@
|
|||||||
|
|
||||||
#define LLAMA_MAX_RNG_STATE (64*1024)
|
#define LLAMA_MAX_RNG_STATE (64*1024)
|
||||||
|
|
||||||
#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
|
|
||||||
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
||||||
|
|
||||||
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
||||||
#define LLAMA_SESSION_VERSION 3
|
#define LLAMA_SESSION_VERSION 2
|
||||||
|
|
||||||
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
|
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
|
||||||
// Defined when llama.cpp is compiled with support for offloading model layers to GPU.
|
// Defined when llama.cpp is compiled with support for offloading model layers to GPU.
|
||||||
@ -127,7 +126,7 @@ extern "C" {
|
|||||||
bool sorted;
|
bool sorted;
|
||||||
} llama_token_data_array;
|
} llama_token_data_array;
|
||||||
|
|
||||||
typedef bool (*llama_progress_callback)(float progress, void *ctx);
|
typedef void (*llama_progress_callback)(float progress, void *ctx);
|
||||||
|
|
||||||
// Input data for llama_decode
|
// Input data for llama_decode
|
||||||
// A llama_batch object can contain input about one or many sequences
|
// A llama_batch object can contain input about one or many sequences
|
||||||
@ -159,38 +158,16 @@ extern "C" {
|
|||||||
llama_seq_id all_seq_id; // used if seq_id == NULL
|
llama_seq_id all_seq_id; // used if seq_id == NULL
|
||||||
} llama_batch;
|
} llama_batch;
|
||||||
|
|
||||||
enum llama_model_kv_override_type {
|
|
||||||
LLAMA_KV_OVERRIDE_INT,
|
|
||||||
LLAMA_KV_OVERRIDE_FLOAT,
|
|
||||||
LLAMA_KV_OVERRIDE_BOOL,
|
|
||||||
};
|
|
||||||
|
|
||||||
struct llama_model_kv_override {
|
|
||||||
char key[128];
|
|
||||||
enum llama_model_kv_override_type tag;
|
|
||||||
union {
|
|
||||||
int64_t int_value;
|
|
||||||
double float_value;
|
|
||||||
bool bool_value;
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
struct llama_model_params {
|
struct llama_model_params {
|
||||||
int32_t n_gpu_layers; // number of layers to store in VRAM
|
int32_t n_gpu_layers; // number of layers to store in VRAM
|
||||||
int32_t main_gpu; // the GPU that is used for scratch and small tensors
|
int32_t main_gpu; // the GPU that is used for scratch and small tensors
|
||||||
const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
|
const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
|
||||||
|
|
||||||
// Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
|
// called with a progress value between 0 and 1, pass NULL to disable
|
||||||
// If the provided progress_callback returns true, model loading continues.
|
|
||||||
// If it returns false, model loading is immediately aborted.
|
|
||||||
llama_progress_callback progress_callback;
|
llama_progress_callback progress_callback;
|
||||||
|
|
||||||
// context pointer passed to the progress callback
|
// context pointer passed to the progress callback
|
||||||
void * progress_callback_user_data;
|
void * progress_callback_user_data;
|
||||||
|
|
||||||
// override key-value pairs of the model meta data
|
|
||||||
const struct llama_model_kv_override * kv_overrides;
|
|
||||||
|
|
||||||
// Keep the booleans together to avoid misalignment during copy-by-value.
|
// Keep the booleans together to avoid misalignment during copy-by-value.
|
||||||
bool vocab_only; // only load the vocabulary, no weights
|
bool vocab_only; // only load the vocabulary, no weights
|
||||||
bool use_mmap; // use mmap if possible
|
bool use_mmap; // use mmap if possible
|
||||||
@ -208,20 +185,17 @@ extern "C" {
|
|||||||
// ref: https://github.com/ggerganov/llama.cpp/pull/2054
|
// ref: https://github.com/ggerganov/llama.cpp/pull/2054
|
||||||
float rope_freq_base; // RoPE base frequency, 0 = from model
|
float rope_freq_base; // RoPE base frequency, 0 = from model
|
||||||
float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
|
float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
|
||||||
float yarn_ext_factor; // YaRN extrapolation mix factor, negative = from model
|
float yarn_ext_factor; // YaRN extrapolation mix factor, NaN = from model
|
||||||
float yarn_attn_factor; // YaRN magnitude scaling factor
|
float yarn_attn_factor; // YaRN magnitude scaling factor
|
||||||
float yarn_beta_fast; // YaRN low correction dim
|
float yarn_beta_fast; // YaRN low correction dim
|
||||||
float yarn_beta_slow; // YaRN high correction dim
|
float yarn_beta_slow; // YaRN high correction dim
|
||||||
uint32_t yarn_orig_ctx; // YaRN original context size
|
uint32_t yarn_orig_ctx; // YaRN original context size
|
||||||
|
|
||||||
enum ggml_type type_k; // data type for K cache
|
|
||||||
enum ggml_type type_v; // data type for V cache
|
|
||||||
|
|
||||||
// Keep the booleans together to avoid misalignment during copy-by-value.
|
// Keep the booleans together to avoid misalignment during copy-by-value.
|
||||||
bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
|
bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
|
||||||
bool logits_all; // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
|
bool f16_kv; // use fp16 for KV cache, fp32 otherwise
|
||||||
|
bool logits_all; // the llama_eval() call computes all logits, not just the last one
|
||||||
bool embedding; // embedding mode only
|
bool embedding; // embedding mode only
|
||||||
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// model quantization parameters
|
// model quantization parameters
|
||||||
@ -316,9 +290,7 @@ extern "C" {
|
|||||||
|
|
||||||
LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
|
LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
|
||||||
|
|
||||||
// TODO: become more consistent with returned int types across the API
|
LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
|
||||||
LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
|
|
||||||
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
|
|
||||||
|
|
||||||
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
|
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
|
||||||
|
|
||||||
@ -329,23 +301,6 @@ extern "C" {
|
|||||||
// Get the model's RoPE frequency scaling factor
|
// Get the model's RoPE frequency scaling factor
|
||||||
LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
|
LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
|
||||||
|
|
||||||
// Functions to access the model's GGUF metadata scalar values
|
|
||||||
// - The functions return the length of the string on success, or -1 on failure
|
|
||||||
// - The output string is always null-terminated and cleared on failure
|
|
||||||
// - GGUF array values are not supported by these functions
|
|
||||||
|
|
||||||
// Get metadata value as a string by key name
|
|
||||||
LLAMA_API int llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size);
|
|
||||||
|
|
||||||
// Get the number of metadata key/value pairs
|
|
||||||
LLAMA_API int llama_model_meta_count(const struct llama_model * model);
|
|
||||||
|
|
||||||
// Get metadata key name by index
|
|
||||||
LLAMA_API int llama_model_meta_key_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size);
|
|
||||||
|
|
||||||
// Get metadata value as a string by index
|
|
||||||
LLAMA_API int llama_model_meta_val_str_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size);
|
|
||||||
|
|
||||||
// Get a string describing the model type
|
// Get a string describing the model type
|
||||||
LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
|
LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
|
||||||
|
|
||||||
@ -389,60 +344,9 @@ extern "C" {
|
|||||||
// KV cache
|
// KV cache
|
||||||
//
|
//
|
||||||
|
|
||||||
// Information associated with an individual cell in the KV cache view.
|
// Returns the number of tokens in the KV cache
|
||||||
struct llama_kv_cache_view_cell {
|
LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
|
||||||
// The position for this cell. Takes KV cache shifts into account.
|
"avoid using this, it will be removed in the future, instead - count the tokens in user code");
|
||||||
// May be negative if the cell is not populated.
|
|
||||||
llama_pos pos;
|
|
||||||
};
|
|
||||||
|
|
||||||
// An updateable view of the KV cache.
|
|
||||||
struct llama_kv_cache_view {
|
|
||||||
// Number of KV cache cells. This will be the same as the context size.
|
|
||||||
int32_t n_cells;
|
|
||||||
|
|
||||||
// Maximum number of sequences that can exist in a cell. It's not an error
|
|
||||||
// if there are more sequences in a cell than this value, however they will
|
|
||||||
// not be visible in the view cells_sequences.
|
|
||||||
int32_t n_max_seq;
|
|
||||||
|
|
||||||
// Number of tokens in the cache. For example, if there are two populated
|
|
||||||
// cells, the first with 1 sequence id in it and the second with 2 sequence
|
|
||||||
// ids then you'll have 3 tokens.
|
|
||||||
int32_t token_count;
|
|
||||||
|
|
||||||
// Number of populated cache cells.
|
|
||||||
int32_t used_cells;
|
|
||||||
|
|
||||||
// Maximum contiguous empty slots in the cache.
|
|
||||||
int32_t max_contiguous;
|
|
||||||
|
|
||||||
// Index to the start of the max_contiguous slot range. Can be negative
|
|
||||||
// when cache is full.
|
|
||||||
int32_t max_contiguous_idx;
|
|
||||||
|
|
||||||
// Information for an individual cell.
|
|
||||||
struct llama_kv_cache_view_cell * cells;
|
|
||||||
|
|
||||||
// The sequences for each cell. There will be n_max_seq items per cell.
|
|
||||||
llama_seq_id * cells_sequences;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Create an empty KV cache view. (use only for debugging purposes)
|
|
||||||
LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq);
|
|
||||||
|
|
||||||
// Free a KV cache view. (use only for debugging purposes)
|
|
||||||
LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
|
|
||||||
|
|
||||||
// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
|
|
||||||
LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
|
|
||||||
|
|
||||||
// Returns the number of tokens in the KV cache (slow, use only for debug)
|
|
||||||
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
|
|
||||||
LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
|
|
||||||
|
|
||||||
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
|
|
||||||
LLAMA_API int llama_get_kv_cache_used_cells(const struct llama_context * ctx);
|
|
||||||
|
|
||||||
// Clear the KV cache
|
// Clear the KV cache
|
||||||
LLAMA_API void llama_kv_cache_clear(
|
LLAMA_API void llama_kv_cache_clear(
|
||||||
@ -613,12 +517,6 @@ extern "C" {
|
|||||||
LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
|
LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
|
||||||
LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
|
LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
|
||||||
|
|
||||||
// Returns -1 if unknown, 1 for true or 0 for false.
|
|
||||||
LLAMA_API int llama_add_bos_token(const struct llama_model * model);
|
|
||||||
|
|
||||||
// Returns -1 if unknown, 1 for true or 0 for false.
|
|
||||||
LLAMA_API int llama_add_eos_token(const struct llama_model * model);
|
|
||||||
|
|
||||||
// codellama infill tokens
|
// codellama infill tokens
|
||||||
LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
|
LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
|
||||||
LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
|
LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
|
||||||
|
@ -282,6 +282,7 @@ int main(int argc, char ** argv) {
|
|||||||
// tune these to your liking
|
// tune these to your liking
|
||||||
lcparams.n_ctx = 2048;
|
lcparams.n_ctx = 2048;
|
||||||
lcparams.seed = 1;
|
lcparams.seed = 1;
|
||||||
|
lcparams.f16_kv = true;
|
||||||
lcparams.n_threads = params.n_threads;
|
lcparams.n_threads = params.n_threads;
|
||||||
|
|
||||||
struct llama_context * ctx_llama = llama_new_context_with_model(model_llama, lcparams);
|
struct llama_context * ctx_llama = llama_new_context_with_model(model_llama, lcparams);
|
||||||
|
@ -155,33 +155,33 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
|||||||
const int n_ctx = hparams.n_ctx;
|
const int n_ctx = hparams.n_ctx;
|
||||||
const int n_vocab = hparams.n_vocab;
|
const int n_vocab = hparams.n_vocab;
|
||||||
|
|
||||||
ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_g
|
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
|
||||||
ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_b
|
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
|
||||||
|
|
||||||
ctx_size += n_vocab*ggml_row_size(wtype, n_embd); // wte
|
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // wte
|
||||||
ctx_size += n_ctx*ggml_row_size(GGML_TYPE_F32, n_embd); // wpe
|
ctx_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
|
||||||
ctx_size += n_vocab*ggml_row_size(wtype, n_embd); // lm_head
|
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // lm_head
|
||||||
|
|
||||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_g
|
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
|
||||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_b
|
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
|
||||||
|
|
||||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_g
|
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
|
||||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_b
|
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
|
||||||
|
|
||||||
ctx_size += n_layer*(ggml_row_size(wtype, 3*n_embd*n_embd)); // c_attn_attn_w
|
ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w
|
||||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 3*n_embd)); // c_attn_attn_b
|
ctx_size += n_layer*( 3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
|
||||||
|
|
||||||
ctx_size += n_layer*(ggml_row_size(wtype, n_embd*n_embd)); // c_attn_proj_w
|
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
|
||||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // c_attn_proj_b
|
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b
|
||||||
|
|
||||||
ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_fc_w
|
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
|
||||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd)); // c_mlp_fc_b
|
ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
|
||||||
|
|
||||||
ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_proj_w
|
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w
|
||||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // c_mlp_proj_b
|
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
|
||||||
|
|
||||||
ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_k
|
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
|
||||||
ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_v
|
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
|
||||||
|
|
||||||
ctx_size += (6 + 12*n_layer)*256; // object overhead
|
ctx_size += (6 + 12*n_layer)*256; // object overhead
|
||||||
|
|
||||||
@ -524,7 +524,8 @@ bool gpt2_eval(
|
|||||||
struct ggml_tensor * KQ_scaled =
|
struct ggml_tensor * KQ_scaled =
|
||||||
ggml_scale(ctx0,
|
ggml_scale(ctx0,
|
||||||
KQ,
|
KQ,
|
||||||
1.0f/sqrt(float(n_embd)/n_head));
|
ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
|
||||||
|
);
|
||||||
|
|
||||||
// KQ_masked = mask_past(KQ_scaled)
|
// KQ_masked = mask_past(KQ_scaled)
|
||||||
// [n_past + N, N, 12]
|
// [n_past + N, N, 12]
|
||||||
|
@ -155,33 +155,33 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
|||||||
const int n_ctx = hparams.n_ctx;
|
const int n_ctx = hparams.n_ctx;
|
||||||
const int n_vocab = hparams.n_vocab;
|
const int n_vocab = hparams.n_vocab;
|
||||||
|
|
||||||
ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_g
|
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
|
||||||
ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_b
|
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
|
||||||
|
|
||||||
ctx_size += n_vocab*ggml_row_size(wtype, n_embd); // wte
|
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // wte
|
||||||
ctx_size += n_ctx*ggml_row_size(GGML_TYPE_F32, n_embd); // wpe
|
ctx_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
|
||||||
ctx_size += n_vocab*ggml_row_size(wtype, n_embd); // lm_head
|
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // lm_head
|
||||||
|
|
||||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_g
|
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
|
||||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_b
|
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
|
||||||
|
|
||||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_g
|
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
|
||||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_b
|
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
|
||||||
|
|
||||||
ctx_size += n_layer*(ggml_row_size(wtype, 3*n_embd*n_embd)); // c_attn_attn_w
|
ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w
|
||||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 3*n_embd)); // c_attn_attn_b
|
ctx_size += n_layer*( 3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
|
||||||
|
|
||||||
ctx_size += n_layer*(ggml_row_size(wtype, n_embd*n_embd)); // c_attn_proj_w
|
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
|
||||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // c_attn_proj_b
|
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b
|
||||||
|
|
||||||
ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_fc_w
|
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
|
||||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd)); // c_mlp_fc_b
|
ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
|
||||||
|
|
||||||
ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_proj_w
|
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w
|
||||||
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // c_mlp_proj_b
|
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
|
||||||
|
|
||||||
ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_k
|
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
|
||||||
ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_v
|
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
|
||||||
|
|
||||||
ctx_size += (6 + 12*n_layer)*256; // object overhead
|
ctx_size += (6 + 12*n_layer)*256; // object overhead
|
||||||
|
|
||||||
@ -525,7 +525,8 @@ bool gpt2_eval(
|
|||||||
struct ggml_tensor * KQ_scaled =
|
struct ggml_tensor * KQ_scaled =
|
||||||
ggml_scale(ctx0,
|
ggml_scale(ctx0,
|
||||||
KQ,
|
KQ,
|
||||||
1.0f/sqrt(float(n_embd)/n_head));
|
ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
|
||||||
|
);
|
||||||
|
|
||||||
// KQ_masked = mask_past(KQ_scaled)
|
// KQ_masked = mask_past(KQ_scaled)
|
||||||
// [n_past + N, N, 12]
|
// [n_past + N, N, 12]
|
||||||
|
@ -1,9 +0,0 @@
|
|||||||
set(CMAKE_CXX_STANDARD 11)
|
|
||||||
|
|
||||||
add_subdirectory(libwchess)
|
|
||||||
|
|
||||||
if (EMSCRIPTEN)
|
|
||||||
add_subdirectory(wchess.wasm)
|
|
||||||
else()
|
|
||||||
add_subdirectory(wchess.cmd)
|
|
||||||
endif()
|
|
@ -1,45 +0,0 @@
|
|||||||
# wchess
|
|
||||||
|
|
||||||
Voice-controlled chess using Whisper
|
|
||||||
|
|
||||||
Online demo: https://whisper.ggerganov.com/wchess/
|
|
||||||
|
|
||||||
https://github.com/ggerganov/whisper.cpp/assets/1991296/c2b2f03c-9684-49f3-8106-357d2d4e67fa
|
|
||||||
|
|
||||||
## Command-line tool
|
|
||||||
|
|
||||||
```bash
|
|
||||||
mkdir build && cd build
|
|
||||||
cmake -DWHISPER_SDL2=1 ..
|
|
||||||
make -j
|
|
||||||
|
|
||||||
./bin/wchess -m ../models/ggml-base.en.bin
|
|
||||||
|
|
||||||
Move: start
|
|
||||||
|
|
||||||
a b c d e f g h
|
|
||||||
r n b q k b n r 8
|
|
||||||
p p p p p p p p 7
|
|
||||||
. * . * . * . * 6
|
|
||||||
* . * . * . * . 5
|
|
||||||
. * . * . * . * 4
|
|
||||||
* . * . * . * . 3
|
|
||||||
P P P P P P P P 2
|
|
||||||
R N B Q K B N R 1
|
|
||||||
|
|
||||||
White's turn
|
|
||||||
[(l)isten/(p)ause/(q)uit]:
|
|
||||||
```
|
|
||||||
|
|
||||||
## TODO
|
|
||||||
|
|
||||||
- Fix bugs in the chess moves logic
|
|
||||||
- Improve web-browser audio capture - sometimes it does not record the voice properly
|
|
||||||
- Add support for more languages by making the generated grammar string multilingual
|
|
||||||
- Explore ways to improve the dynamic grammar to be narrower
|
|
||||||
|
|
||||||
PRs welcome!
|
|
||||||
|
|
||||||
## Thanks
|
|
||||||
|
|
||||||
- [chessboardjs](https://chessboardjs.com) for the neat chessboard JS library used in this demo
|
|
@ -1,19 +0,0 @@
|
|||||||
add_library(wchess-core STATIC
|
|
||||||
WChess.cpp
|
|
||||||
WChess.h
|
|
||||||
Chessboard.cpp
|
|
||||||
Chessboard.h
|
|
||||||
)
|
|
||||||
|
|
||||||
target_link_libraries(wchess-core
|
|
||||||
PUBLIC
|
|
||||||
whisper
|
|
||||||
common
|
|
||||||
)
|
|
||||||
|
|
||||||
target_include_directories(wchess-core
|
|
||||||
PUBLIC
|
|
||||||
"$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>"
|
|
||||||
)
|
|
||||||
|
|
||||||
# add_executable(test-chessboard test-chessboard.cpp Chessboard.cpp)
|
|
@ -1,803 +0,0 @@
|
|||||||
#include "Chessboard.h"
|
|
||||||
|
|
||||||
#include <array>
|
|
||||||
#include <vector>
|
|
||||||
#include <algorithm>
|
|
||||||
#include <cstring>
|
|
||||||
#include <set>
|
|
||||||
#include <list>
|
|
||||||
#include <chrono>
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
constexpr std::array<const char*, 64> positions = {
|
|
||||||
"a1", "b1", "c1", "d1", "e1", "f1", "g1", "h1",
|
|
||||||
"a2", "b2", "c2", "d2", "e2", "f2", "g2", "h2",
|
|
||||||
"a3", "b3", "c3", "d3", "e3", "f3", "g3", "h3",
|
|
||||||
"a4", "b4", "c4", "d4", "e4", "f4", "g4", "h4",
|
|
||||||
"a5", "b5", "c5", "d5", "e5", "f5", "g5", "h5",
|
|
||||||
"a6", "b6", "c6", "d6", "e6", "f6", "g6", "h6",
|
|
||||||
"a7", "b7", "c7", "d7", "e7", "f7", "g7", "h7",
|
|
||||||
"a8", "b8", "c8", "d8", "e8", "f8", "g8", "h8",
|
|
||||||
};
|
|
||||||
constexpr char INVALID_POS = positions.size();
|
|
||||||
constexpr int R = 0; // rank index
|
|
||||||
constexpr int F = 1; // file index
|
|
||||||
#define FILE (c[F] - '1')
|
|
||||||
#define RANK (c[R] - 'a')
|
|
||||||
constexpr char operator ""_P(const char * c, size_t size) {
|
|
||||||
return size < 2 || RANK < 0 || RANK > 7 ||
|
|
||||||
FILE < 0 || FILE > 7 ? INVALID_POS : FILE * 8 + RANK;
|
|
||||||
}
|
|
||||||
#undef FILE
|
|
||||||
#undef RANK
|
|
||||||
|
|
||||||
struct sview {
|
|
||||||
const char * ptr = nullptr;
|
|
||||||
size_t size = 0;
|
|
||||||
|
|
||||||
sview() = default;
|
|
||||||
sview(const char * p, size_t s) : ptr(p), size(s) {}
|
|
||||||
sview(const std::string& s) : ptr(s.data()), size(s.size()) {}
|
|
||||||
|
|
||||||
size_t find(char del, size_t pos) {
|
|
||||||
while (pos < size && ptr[pos] != del) ++pos;
|
|
||||||
return pos < size ? pos : std::string::npos;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
std::vector<sview> split(sview str, char del) {
|
|
||||||
std::vector<sview> res;
|
|
||||||
size_t cur = 0;
|
|
||||||
size_t last = 0;
|
|
||||||
while (cur != std::string::npos) {
|
|
||||||
if (str.ptr[last] == ' ') {
|
|
||||||
++last;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
cur = str.find(del, last);
|
|
||||||
size_t len = cur == std::string::npos ? str.size - last : cur - last;
|
|
||||||
res.emplace_back(str.ptr + last, len);
|
|
||||||
last = cur + 1;
|
|
||||||
}
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
char strToPos(sview str) {
|
|
||||||
return operator ""_P(str.ptr, str.size);
|
|
||||||
}
|
|
||||||
|
|
||||||
constexpr std::array<const char*, 6> pieceNames = {
|
|
||||||
"pawn", "knight", "bishop", "rook", "queen", "king",
|
|
||||||
};
|
|
||||||
|
|
||||||
static constexpr std::array<char, 6> blackShort = {
|
|
||||||
'p', 'n', 'b', 'r', 'q', 'k',
|
|
||||||
};
|
|
||||||
static constexpr std::array<char, 6> whiteShort = {
|
|
||||||
'P', 'N', 'B', 'R', 'Q', 'K',
|
|
||||||
};
|
|
||||||
|
|
||||||
char strToType(sview str) {
|
|
||||||
auto it = std::find_if(pieceNames.begin(), pieceNames.end(), [str] (const char* name) { return strncmp(name, str.ptr, str.size) == 0; });
|
|
||||||
return it != pieceNames.end() ? it - pieceNames.begin() : pieceNames.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
// directions
|
|
||||||
using Direction = std::array<char, 2>;
|
|
||||||
|
|
||||||
constexpr Direction N = {(char) 0, (char) 1};
|
|
||||||
constexpr Direction NNE = {(char) 1, (char) 2};
|
|
||||||
constexpr Direction NE = {(char) 1, (char) 1};
|
|
||||||
constexpr Direction ENE = {(char) 2, (char) 1};
|
|
||||||
constexpr Direction E = {(char) 1, (char) 0};
|
|
||||||
constexpr Direction ESE = {(char) 2, (char) -1};
|
|
||||||
constexpr Direction SE = {(char) 1, (char) -1};
|
|
||||||
constexpr Direction SSE = {(char) 1, (char) -2};
|
|
||||||
constexpr Direction S = {(char) 0, (char) -1};
|
|
||||||
constexpr Direction SSW = {(char) -1, (char) -2};
|
|
||||||
constexpr Direction SW = {(char) -1, (char) -1};
|
|
||||||
constexpr Direction WSW = {(char) -2, (char) -1};
|
|
||||||
constexpr Direction W = {(char) -1, (char) 0};
|
|
||||||
constexpr Direction WNW = {(char) -2, (char) 1};
|
|
||||||
constexpr Direction NW = {(char) -1, (char) 1};
|
|
||||||
constexpr Direction NNW = {(char) -1, (char) 2};
|
|
||||||
|
|
||||||
char makeStep(char pos, const Direction& d) {
|
|
||||||
char next[2] = { char(positions[pos][R] + d[R]) , char(positions[pos][F] + d[F]) };
|
|
||||||
return strToPos(sview{next, sizeof(next)});
|
|
||||||
}
|
|
||||||
|
|
||||||
template<class Modifier>
|
|
||||||
char traverse(char pos, const Direction& d, const Modifier& m, int count = 8) {
|
|
||||||
while (--count >= 0) {
|
|
||||||
pos = makeStep(pos, d);
|
|
||||||
if (pos == INVALID_POS || m(pos)) break;
|
|
||||||
}
|
|
||||||
return pos;
|
|
||||||
}
|
|
||||||
|
|
||||||
Direction normalize(const Direction& distance) {
|
|
||||||
//return {char((distance[R] > 0) - (distance[R] < 0)), char((distance[F] > 0) - (distance[F] < 0))};
|
|
||||||
const int drp = distance[R] > 0 ? 1 : 0;
|
|
||||||
const int drn = distance[R] < 0 ? 1 : 0;
|
|
||||||
const int dfp = distance[F] > 0 ? 1 : 0;
|
|
||||||
const int dfn = distance[F] < 0 ? 1 : 0;
|
|
||||||
return {char(drp - drn), char(dfp - dfn)};
|
|
||||||
}
|
|
||||||
|
|
||||||
struct Pin {
|
|
||||||
Direction d;
|
|
||||||
Piece* pinner;
|
|
||||||
Piece* pinned;
|
|
||||||
};
|
|
||||||
using Pins = std::list<Pin>;
|
|
||||||
using Board = std::array<Piece*, 64>;
|
|
||||||
|
|
||||||
std::vector<Direction> filter(const Direction& pin, std::initializer_list<Direction> directions) {
|
|
||||||
if (pin[R] == 0 && pin[F] == 0) return directions;
|
|
||||||
std::vector<Direction> result;
|
|
||||||
for (auto& d : directions) {
|
|
||||||
if ((d[R] == pin[R] || d[R] == -pin[R]) && (d[F] == pin[F] || d[F] == -pin[F])) result.push_back(d);
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
class Piece {
|
|
||||||
public:
|
|
||||||
enum Types : char {
|
|
||||||
Pawn,
|
|
||||||
Knight,
|
|
||||||
Bishop,
|
|
||||||
Rook,
|
|
||||||
Queen,
|
|
||||||
King,
|
|
||||||
//
|
|
||||||
NUM_PIECES
|
|
||||||
};
|
|
||||||
|
|
||||||
enum Colors : char {
|
|
||||||
White,
|
|
||||||
Black,
|
|
||||||
};
|
|
||||||
|
|
||||||
const char* name() const;
|
|
||||||
char initial() const;
|
|
||||||
Types type() const { return m_type; }
|
|
||||||
Colors color() const { return m_color; }
|
|
||||||
char pos() const { return m_pos; }
|
|
||||||
void setPos(char pos) {
|
|
||||||
m_pos = pos;
|
|
||||||
invalidate();
|
|
||||||
}
|
|
||||||
const char* coord() const;
|
|
||||||
const std::set<char>& allowed() const { return m_allowed; }
|
|
||||||
bool canReach(char pos) const;
|
|
||||||
virtual bool movePattern(char pos) const = 0;
|
|
||||||
void take();
|
|
||||||
virtual void reinit(const State& state) = 0;
|
|
||||||
void invalidate();
|
|
||||||
protected:
|
|
||||||
Piece(Types type, Colors color, char pos, std::set<char> allowed)
|
|
||||||
: m_type(type), m_color(color), m_pos(pos), m_allowed(std::move(allowed)) {}
|
|
||||||
Piece(const Piece&) = delete;
|
|
||||||
~Piece() = default;
|
|
||||||
|
|
||||||
const Types m_type;
|
|
||||||
const Colors m_color;
|
|
||||||
char m_pos;
|
|
||||||
std::set<char> m_allowed;
|
|
||||||
bool m_update = false;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct Pawn : public Piece {
|
|
||||||
Pawn(Colors color, char pos, std::set<char> next) : Piece(Types::Pawn, color, pos, std::move(next)) {}
|
|
||||||
|
|
||||||
bool is_first_move() const {
|
|
||||||
return m_color ? coord()[F] == '7' : coord()[F] == '2';
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual bool movePattern(char pos) const override {
|
|
||||||
if (m_pos == INVALID_POS) return false;
|
|
||||||
auto cur = coord();
|
|
||||||
auto next = positions[pos];
|
|
||||||
Direction distance = {char(next[R] - cur[R]), char(next[F] - cur[F])};
|
|
||||||
char forward = m_color ? -1 : 1;
|
|
||||||
return (forward == distance[F] && distance[R] * distance[R] <= 1)
|
|
||||||
|| (is_first_move() && 2 * forward == distance[F] && distance[R] == 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual void reinit(const State& state) override;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct Knight : public Piece {
|
|
||||||
Knight(Colors color, char pos, std::set<char> next) : Piece(Types::Knight, color, pos, std::move(next)) {}
|
|
||||||
|
|
||||||
virtual bool movePattern(char pos) const override {
|
|
||||||
if (m_pos == INVALID_POS) return false;
|
|
||||||
auto cur = coord();
|
|
||||||
auto next = positions[pos];
|
|
||||||
Direction diff = {char(next[R] - cur[R]), char(next[F] - cur[F])};
|
|
||||||
return diff[R]*diff[R] + diff[F]*diff[F] == 5;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual void reinit(const State& state) override;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct Bishop : public Piece {
|
|
||||||
Bishop(Colors color, char pos) : Piece(Types::Bishop, color, pos, {}) {}
|
|
||||||
|
|
||||||
virtual bool movePattern(char pos) const override {
|
|
||||||
if (m_pos == INVALID_POS) return false;
|
|
||||||
auto cur = coord();
|
|
||||||
auto next = positions[pos];
|
|
||||||
return cur[R] - cur[F] == next[R] - next[F] || cur[R] + cur[F] == next[R] + next[F];
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual void reinit(const State& state) override;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct Rook : public Piece {
|
|
||||||
Rook(Colors color, char pos) : Piece(Types::Rook, color, pos, {}) {}
|
|
||||||
|
|
||||||
virtual bool movePattern(char pos) const override {
|
|
||||||
if (m_pos == INVALID_POS) return false;
|
|
||||||
auto cur = coord();
|
|
||||||
auto next = positions[pos];
|
|
||||||
return cur[R] == next[R] || cur[F] == next[F];
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual void reinit(const State& state) override;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct Queen : public Piece {
|
|
||||||
Queen(Colors color, char pos) : Piece(Types::Queen, color, pos, {}) {}
|
|
||||||
|
|
||||||
virtual bool movePattern(char pos) const override {
|
|
||||||
if (m_pos == INVALID_POS) return false;
|
|
||||||
auto cur = coord();
|
|
||||||
auto next = positions[pos];
|
|
||||||
return cur[R] == next[R] || cur[F] == next[F] || cur[R] - cur[F] == next[R] - next[F] || cur[R] + cur[F] == next[R] + next[F];
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual void reinit(const State& state) override;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct King : public Piece {
|
|
||||||
King(Colors color, char pos) : Piece(Types::King, color, pos, {}) {}
|
|
||||||
|
|
||||||
virtual bool movePattern(char pos) const override {
|
|
||||||
if (m_pos == INVALID_POS) return false;
|
|
||||||
auto cur = coord();
|
|
||||||
auto next = positions[pos];
|
|
||||||
Direction diff = {char(next[R] - cur[R]), char(next[F] - cur[F])};
|
|
||||||
return diff[R]*diff[R] + diff[F]*diff[F] <= 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual void reinit(const State& state) override;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct PieceSet {
|
|
||||||
Piece* begin() { return &p1; }
|
|
||||||
Piece* end() { return &r2 + 1; }
|
|
||||||
const Piece* begin() const { return &p1; }
|
|
||||||
const Piece* end() const { return &r2 + 1; }
|
|
||||||
Piece& operator[](int i) { return *(begin() + i); }
|
|
||||||
const Piece& operator[](int i) const { return *(begin() + i); }
|
|
||||||
|
|
||||||
Pawn p1;
|
|
||||||
Pawn p2;
|
|
||||||
Pawn p3;
|
|
||||||
Pawn p4;
|
|
||||||
Pawn p5;
|
|
||||||
Pawn p6;
|
|
||||||
Pawn p7;
|
|
||||||
Pawn p8;
|
|
||||||
Rook r1;
|
|
||||||
Knight n1;
|
|
||||||
Bishop b1;
|
|
||||||
Queen q;
|
|
||||||
King k;
|
|
||||||
Bishop b2;
|
|
||||||
Knight n2;
|
|
||||||
Rook r2;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct State {
|
|
||||||
State();
|
|
||||||
PieceSet blacks;
|
|
||||||
PieceSet whites;
|
|
||||||
Board board;
|
|
||||||
Pins blackPins;
|
|
||||||
Pins whitePins;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Looks up the pin (if any) that currently restricts `piece`.
// Returns the direction stored in the matching pin, or {0, 0} when the piece
// is not listed as pinned.
Direction findPin(const Piece& piece, const State& state) {
    auto& candidates = piece.color() ? state.blackPins : state.whitePins;
    for (const Pin& pin : candidates) {
        if (pin.pinned == &piece) {
            return pin.d;
        }
    }
    return {0, 0};
}
|
|
||||||
|
|
||||||
struct Find {
|
|
||||||
Find(const Board& board) : m_board(board) {}
|
|
||||||
bool operator() (char pos) const { return m_board[pos]; }
|
|
||||||
const Board& m_board;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct Add {
|
|
||||||
Add(const Board& board, std::set<char>& moves, Piece::Colors color) : m_board(board), m_moves(moves), m_color(color) {}
|
|
||||||
bool operator() (char pos) const {
|
|
||||||
if (!m_board[pos] || m_board[pos]->color() != m_color) m_moves.insert(pos);
|
|
||||||
return m_board[pos];
|
|
||||||
}
|
|
||||||
const Board& m_board;
|
|
||||||
std::set<char>& m_moves;
|
|
||||||
Piece::Colors m_color;
|
|
||||||
};
|
|
||||||
|
|
||||||
void Pawn::reinit(const State& state) {
|
|
||||||
if (m_pos == INVALID_POS) return;
|
|
||||||
if (!m_update) return;
|
|
||||||
m_update = false;
|
|
||||||
m_allowed.clear();
|
|
||||||
|
|
||||||
auto pin = findPin(*this, state);
|
|
||||||
|
|
||||||
auto & left = m_color ? SW : NW;
|
|
||||||
auto & right = m_color ? SE : NE;
|
|
||||||
|
|
||||||
for (auto& direction : filter(pin, { left, right })) {
|
|
||||||
auto pos = makeStep(m_pos, direction);
|
|
||||||
if (pos != INVALID_POS && state.board[pos] && state.board[pos]->color() != m_color) m_allowed.insert(pos);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto & forward = m_color ? S : N;
|
|
||||||
if (!filter(pin, {forward}).empty()) {
|
|
||||||
traverse(m_pos, forward, [&] (char pos) {
|
|
||||||
if (!state.board[pos]) m_allowed.insert(pos);
|
|
||||||
return state.board[pos] || !is_first_move();
|
|
||||||
}, 2);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void Knight::reinit(const State& state) {
|
|
||||||
if (m_pos == INVALID_POS) return;
|
|
||||||
if (!m_update) return;
|
|
||||||
m_update = false;
|
|
||||||
m_allowed.clear();
|
|
||||||
auto pin = findPin(*this, state);
|
|
||||||
if (pin[R] != 0 || pin[F] != 0) return;
|
|
||||||
for (auto& direction : { NNE, ENE, ESE, SSE, SSW, WSW, WNW, NNW }) {
|
|
||||||
auto pos = makeStep(m_pos, direction);
|
|
||||||
if (pos != INVALID_POS && (!state.board[pos] || state.board[pos]->color() != m_color)) m_allowed.insert(pos);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void Bishop::reinit(const State& state) {
|
|
||||||
if (m_pos == INVALID_POS) return;
|
|
||||||
if (!m_update) return;
|
|
||||||
m_update = false;
|
|
||||||
m_allowed.clear();
|
|
||||||
auto pin = findPin(*this, state);
|
|
||||||
for (auto& direction : filter(pin, { NE, SE, SW, NW })) {
|
|
||||||
traverse(m_pos, direction, Add(state.board, m_allowed, m_color));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void Rook::reinit(const State& state) {
|
|
||||||
if (m_pos == INVALID_POS) return;
|
|
||||||
if (!m_update) return;
|
|
||||||
m_update = false;
|
|
||||||
m_allowed.clear();
|
|
||||||
auto pin = findPin(*this, state);
|
|
||||||
for (auto& direction : filter(pin, { N, E, S, W })) {
|
|
||||||
traverse(m_pos, direction, Add(state.board, m_allowed, m_color));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void Queen::reinit(const State& state) {
|
|
||||||
if (m_pos == INVALID_POS) return;
|
|
||||||
if (!m_update) return;
|
|
||||||
m_update = false;
|
|
||||||
m_allowed.clear();
|
|
||||||
auto pin = findPin(*this, state);
|
|
||||||
for (auto& direction : filter(pin, { N, NE, E, SE, S, SW, W, NW })) {
|
|
||||||
traverse(m_pos, direction, Add(state.board, m_allowed, m_color));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void King::reinit(const State& state) {
|
|
||||||
if (m_pos == INVALID_POS) return;
|
|
||||||
if (!m_update) return;
|
|
||||||
m_update = false;
|
|
||||||
m_allowed.clear();
|
|
||||||
auto& enemyPieces = m_color ? state.whites : state.blacks;
|
|
||||||
auto& pawnAttackLeft = m_color ? SW : NW;
|
|
||||||
auto& pawnAttackRight = m_color ? SE : NE;
|
|
||||||
for (auto& direction : { N, NE, E, SE, S, SW, W, NW }) {
|
|
||||||
auto pos = makeStep(m_pos, direction);
|
|
||||||
bool accept = pos != INVALID_POS && !(state.board[pos] && state.board[pos]->color() == m_color);
|
|
||||||
if (accept) {
|
|
||||||
for (auto& p : enemyPieces) {
|
|
||||||
if (!p.movePattern(pos)) continue;
|
|
||||||
if (p.type() == Piece::Knight || p.type() == Piece::King) {
|
|
||||||
accept = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
else if (p.type() == Piece::Pawn) {
|
|
||||||
auto from = positions[pos];
|
|
||||||
auto to = p.coord();
|
|
||||||
Direction d {char(to[R] - from[R]), char(to[F] - from[F])};
|
|
||||||
if (d == pawnAttackLeft || d == pawnAttackRight) {
|
|
||||||
accept = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
auto from = positions[pos];
|
|
||||||
auto to = p.coord();
|
|
||||||
Direction d = normalize({char(to[R] - from[R]), char(to[F] - from[F])});
|
|
||||||
auto reached = traverse(pos, d, Find(state.board));
|
|
||||||
if (p.pos() == reached) {
|
|
||||||
accept = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (accept) m_allowed.insert(pos);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const char* Piece::name() const {
|
|
||||||
static_assert(pieceNames.size() == Piece::NUM_PIECES, "Mismatch between piece names and types");
|
|
||||||
return pieceNames[m_type];
|
|
||||||
}
|
|
||||||
|
|
||||||
char Piece::initial() const {
|
|
||||||
static_assert(blackShort.size() == Piece::NUM_PIECES, "Mismatch between piece names and types");
|
|
||||||
static_assert(whiteShort.size() == Piece::NUM_PIECES, "Mismatch between piece names and types");
|
|
||||||
return m_color ? blackShort[m_type] : whiteShort[m_type];
|
|
||||||
}
|
|
||||||
|
|
||||||
void Piece::invalidate() {
|
|
||||||
m_update = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
const char* Piece::coord() const {
|
|
||||||
if (m_pos == INVALID_POS) return "";
|
|
||||||
return positions[m_pos];
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Piece::canReach(char pos) const {
|
|
||||||
return movePattern(pos) && m_allowed.count(pos);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Marks the piece as removed from the board: it gets the invalid position and
// loses all allowed squares.
void Piece::take() {
    m_allowed.clear();
    m_pos = INVALID_POS;
}
|
|
||||||
|
|
||||||
State::State()
|
|
||||||
: blacks {
|
|
||||||
{Piece::Black, "a7"_P, {"a5"_P, "a6"_P} },
|
|
||||||
{Piece::Black, "b7"_P, {"b5"_P, "b6"_P} },
|
|
||||||
{Piece::Black, "c7"_P, {"c5"_P, "c6"_P} },
|
|
||||||
{Piece::Black, "d7"_P, {"d5"_P, "d6"_P} },
|
|
||||||
{Piece::Black, "e7"_P, {"e5"_P, "e6"_P} },
|
|
||||||
{Piece::Black, "f7"_P, {"f5"_P, "f6"_P} },
|
|
||||||
{Piece::Black, "g7"_P, {"g5"_P, "g6"_P} },
|
|
||||||
{Piece::Black, "h7"_P, {"h5"_P, "h6"_P} },
|
|
||||||
{Piece::Black, "a8"_P},
|
|
||||||
{Piece::Black, "b8"_P, {"a6"_P, "c6"_P} },
|
|
||||||
{Piece::Black, "c8"_P},
|
|
||||||
{Piece::Black, "d8"_P},
|
|
||||||
{Piece::Black, "e8"_P},
|
|
||||||
{Piece::Black, "f8"_P},
|
|
||||||
{Piece::Black, "g8"_P, {"f6"_P, "h6"_P} },
|
|
||||||
{Piece::Black, "h8"_P},
|
|
||||||
}
|
|
||||||
, whites {
|
|
||||||
{Piece::White, "a2"_P, {"a3"_P, "a4"_P} },
|
|
||||||
{Piece::White, "b2"_P, {"b3"_P, "b4"_P} },
|
|
||||||
{Piece::White, "c2"_P, {"c3"_P, "c4"_P} },
|
|
||||||
{Piece::White, "d2"_P, {"d3"_P, "d4"_P} },
|
|
||||||
{Piece::White, "e2"_P, {"e3"_P, "e4"_P} },
|
|
||||||
{Piece::White, "f2"_P, {"f3"_P, "f4"_P} },
|
|
||||||
{Piece::White, "g2"_P, {"g3"_P, "g4"_P} },
|
|
||||||
{Piece::White, "h2"_P, {"h3"_P, "h4"_P} },
|
|
||||||
{Piece::White, "a1"_P},
|
|
||||||
{Piece::White, "b1"_P, {"a3"_P, "c3"_P} },
|
|
||||||
{Piece::White, "c1"_P},
|
|
||||||
{Piece::White, "d1"_P},
|
|
||||||
{Piece::White, "e1"_P},
|
|
||||||
{Piece::White, "f1"_P},
|
|
||||||
{Piece::White, "g1"_P, {"f3"_P, "h3"_P} },
|
|
||||||
{Piece::White, "h1"_P},
|
|
||||||
}
|
|
||||||
, board {{
|
|
||||||
&whites[ 8], &whites[ 9], &whites[10], &whites[11], &whites[12], &whites[13], &whites[14], &whites[15],
|
|
||||||
&whites[ 0], &whites[ 1], &whites[ 2], &whites[ 3], &whites[ 4], &whites[ 5], &whites[ 6], &whites[ 7],
|
|
||||||
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
|
|
||||||
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
|
|
||||||
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
|
|
||||||
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
|
|
||||||
&blacks[ 0], &blacks[ 1], &blacks[ 2], &blacks[ 3], &blacks[ 4], &blacks[ 5], &blacks[ 6], &blacks[ 7],
|
|
||||||
&blacks[ 8], &blacks[ 9], &blacks[10], &blacks[11], &blacks[12], &blacks[13], &blacks[14], &blacks[15],
|
|
||||||
}}
|
|
||||||
{}
|
|
||||||
|
|
||||||
// Creates a board in the initial position and builds the grammar for the
// first move.
Chessboard::Chessboard() : m_state(new State) {
    setGrammar();
}
|
|
||||||
|
|
||||||
// Defaulted here rather than in the class body; the header only forward-declares
// State while holding it in a std::unique_ptr (presumably the reason for this placement).
Chessboard::~Chessboard() = default;
|
|
||||||
|
|
||||||
// Stores the activation prompt and rebuilds the grammar, whose `move` rule
// depends on whether a prompt is set.
void Chessboard::setPrompt(const std::string& prompt) {
    m_prompt.assign(prompt);
    setGrammar();
}
|
|
||||||
|
|
||||||
void Chessboard::setGrammar() {
|
|
||||||
m_grammar.clear();
|
|
||||||
|
|
||||||
std::string result;
|
|
||||||
if (m_prompt.empty()) {
|
|
||||||
result += "move ::= \" \" ((piece | frompos) \" \" \"to \"?)? topos\n";
|
|
||||||
//result += "move ::= \" \" frompos \" \" \"to \"? topos\n";
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
// result += "move ::= prompt \" \" ((piece | frompos) \" \" \"to \"?)? topos\n"
|
|
||||||
result += "move ::= prompt \" \" frompos \" \" \"to \"? topos\n"
|
|
||||||
"prompt ::= \" " + m_prompt + "\"\n";
|
|
||||||
}
|
|
||||||
|
|
||||||
std::set<Piece::Types> pieceTypes;
|
|
||||||
std::set<char> from_pos;
|
|
||||||
std::set<char> to_pos;
|
|
||||||
auto& pieces = m_moveCounter % 2 ? m_state->blacks : m_state->whites;
|
|
||||||
std::set<size_t> flags;
|
|
||||||
for (auto& p : pieces) {
|
|
||||||
if (p.allowed().empty()) continue;
|
|
||||||
bool addPiece = false;
|
|
||||||
if (!m_inCheck || p.type() == Piece::King) {
|
|
||||||
to_pos.insert(p.allowed().begin(), p.allowed().end());
|
|
||||||
addPiece = !p.allowed().empty();
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
for (auto move : p.allowed()) {
|
|
||||||
if (m_allowedInCheck.count(move)) {
|
|
||||||
to_pos.insert(move);
|
|
||||||
addPiece = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (addPiece) {
|
|
||||||
pieceTypes.insert(p.type());
|
|
||||||
from_pos.insert(p.pos());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (pieceTypes.empty()) return;
|
|
||||||
|
|
||||||
result += "piece ::= (";
|
|
||||||
for (auto& p : pieceTypes) result += " \"" + std::string(pieceNames[p]) + "\" |";
|
|
||||||
result.pop_back();
|
|
||||||
result += ")\n\n";
|
|
||||||
|
|
||||||
result += "frompos ::= (";
|
|
||||||
for (auto& p : from_pos) result += " \"" + std::string(positions[p]) + "\" |";
|
|
||||||
result.pop_back();
|
|
||||||
result += ")\n";
|
|
||||||
|
|
||||||
result += "topos ::= (";
|
|
||||||
for (auto& p : to_pos) result += " \"" + std::string(positions[p]) + "\" |";
|
|
||||||
result.pop_back();
|
|
||||||
result += ")\n";
|
|
||||||
|
|
||||||
m_grammar = std::move(result);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string Chessboard::stringifyBoard() {
|
|
||||||
std::string result;
|
|
||||||
result.reserve(16 + 2 * 64 + 16);
|
|
||||||
for (char rank = 'a'; rank <= 'h'; ++rank) {
|
|
||||||
result.push_back(rank);
|
|
||||||
result.push_back(' ');
|
|
||||||
}
|
|
||||||
result.back() = '\n';
|
|
||||||
for (int i = 7; i >= 0; --i) {
|
|
||||||
for (int j = 0; j < 8; ++j) {
|
|
||||||
auto p = m_state->board[i * 8 + j];
|
|
||||||
if (p) result.push_back(p->initial());
|
|
||||||
else result.push_back((i + j) % 2 ? '.' : '*');
|
|
||||||
result.push_back(' ');
|
|
||||||
}
|
|
||||||
result.push_back('0' + i + 1);
|
|
||||||
result.push_back('\n');
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string Chessboard::process(const std::string& command) {
|
|
||||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
|
||||||
auto color = Piece::Colors(m_moveCounter % 2);
|
|
||||||
Piece* piece = nullptr;
|
|
||||||
auto pos_to = INVALID_POS;
|
|
||||||
if (!parseCommand(command, piece, pos_to)) return "";
|
|
||||||
|
|
||||||
auto pos_from = piece->pos();
|
|
||||||
|
|
||||||
if (!move(*piece, pos_to)) return "";
|
|
||||||
|
|
||||||
flagUpdates(pos_from, pos_to);
|
|
||||||
|
|
||||||
detectChecks();
|
|
||||||
|
|
||||||
auto& enemyPieces = color ? m_state->whites : m_state->blacks;
|
|
||||||
for (auto& p : enemyPieces) p.reinit(*m_state); // only enemy moves needed next
|
|
||||||
|
|
||||||
std::string result = {positions[pos_from][R], positions[pos_from][F], '-', positions[pos_to][R], positions[pos_to][F]};
|
|
||||||
++m_moveCounter;
|
|
||||||
setGrammar();
|
|
||||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
|
||||||
auto t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
|
|
||||||
fprintf(stdout, "%s: Move '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", result.data(), "\033[0m", (int) t_ms);
|
|
||||||
if (m_grammar.empty()) result.push_back('#');
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Chessboard::parseCommand(const std::string& command, Piece*& piece, char& pos_to) {
|
|
||||||
auto color = Piece::Colors(m_moveCounter % 2);
|
|
||||||
fprintf(stdout, "%s: Command to %s: '%s%.*s%s'\n", __func__, (color ? "Black" : "White"), "\033[1m", int(command.size()), command.data(), "\033[0m");
|
|
||||||
|
|
||||||
if (command.empty()) return false;
|
|
||||||
auto tokens = split(command, ' ');
|
|
||||||
auto pos_from = INVALID_POS;
|
|
||||||
auto type = Piece::Types::NUM_PIECES;
|
|
||||||
if (tokens.size() == 1) {
|
|
||||||
type = Piece::Types::Pawn;
|
|
||||||
pos_to = strToPos(tokens.front());
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
pos_from = strToPos(tokens.front());
|
|
||||||
if (pos_from == INVALID_POS) type = Piece::Types(strToType(tokens.front()));
|
|
||||||
pos_to = strToPos(tokens.back());
|
|
||||||
}
|
|
||||||
if (pos_to == INVALID_POS) return false;
|
|
||||||
if (pos_from == INVALID_POS) {
|
|
||||||
if (type == Piece::Types::NUM_PIECES) return false;
|
|
||||||
auto& pieces = color ? m_state->blacks : m_state->whites;
|
|
||||||
for (auto& p : pieces) {
|
|
||||||
if (p.type() == type && p.canReach(pos_to)) {
|
|
||||||
pos_from = p.pos();
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (pos_from == INVALID_POS) return false;
|
|
||||||
if (m_state->board[pos_from] == nullptr) return false;
|
|
||||||
piece = m_state->board[pos_from];
|
|
||||||
if (piece->color() != color) return false;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Chessboard::flagUpdates(char pos_from, char pos_to) {
|
|
||||||
auto color = Piece::Colors(m_moveCounter % 2);
|
|
||||||
auto& enemyPieces = color ? m_state->whites : m_state->blacks;
|
|
||||||
auto& ownPieces = color ? m_state->blacks : m_state->whites;
|
|
||||||
for (auto& p : enemyPieces) {
|
|
||||||
if (p.movePattern(pos_to) || p.movePattern(pos_from)) {
|
|
||||||
updatePins(p);
|
|
||||||
p.invalidate();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto& p : ownPieces) {
|
|
||||||
if (p.movePattern(pos_to) || p.movePattern(pos_from)) {
|
|
||||||
updatePins(p);
|
|
||||||
p.invalidate();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Re-evaluates the pin that the sliding piece `piece` exerts on the opposing
// side. Pawns, knights and kings are ignored. Any pin previously recorded for
// this piece is dropped (its pinned piece is invalidated); then, if the piece
// is aligned with the enemy king, the line towards the king is inspected:
//  - king reached directly: the king is invalidated (it is in check);
//  - exactly one enemy piece in between: that piece is recorded as pinned.
void Chessboard::updatePins(Piece& piece) {
    const auto type = piece.type();
    if (type == Piece::Pawn || type == Piece::Knight || type == Piece::King) return;

    auto& enemyPieces = piece.color() ? m_state->whites : m_state->blacks;
    auto& enemyPins = piece.color() ? m_state->whitePins : m_state->blackPins;
    auto& king = enemyPieces.k;

    // Forget the pin this piece was responsible for, if any.
    auto stale = std::find_if(enemyPins.begin(), enemyPins.end(),
                              [&] (const Pin& pin) { return pin.pinner == &piece; });
    if (stale != enemyPins.end()) {
        stale->pinned->invalidate();
        enemyPins.erase(stale);
    }

    if (!piece.movePattern(king.pos())) return;

    auto to = positions[king.pos()];
    auto from = piece.coord();
    Direction d = normalize({char(to[R] - from[R]), char(to[F] - from[F])});

    // First occupied square on the way to the king.
    auto reached = traverse(piece.pos(), d, Find(m_state->board));
    auto foundPiece = m_state->board[reached];

    if (&king == foundPiece) {
        // check
        king.invalidate();
        return;
    }
    if (!foundPiece || foundPiece->color() == piece.color()) return;

    // An enemy piece is in the way: it is pinned if the king is next in line.
    reached = traverse(reached, d, Find(m_state->board));
    if (&king == m_state->board[reached]) {
        enemyPins.push_back({d, &piece, foundPiece});
        foundPiece->invalidate();
    }
}
|
|
||||||
|
|
||||||
void Chessboard::detectChecks() {
|
|
||||||
auto color = Piece::Colors(m_moveCounter % 2);
|
|
||||||
auto& enemyPieces = color ? m_state->whites : m_state->blacks;
|
|
||||||
auto& ownPieces = color ? m_state->blacks : m_state->whites;
|
|
||||||
auto& king = enemyPieces.k;
|
|
||||||
auto& pawnAttackLeft = color ? SW : NW;
|
|
||||||
auto& pawnAttackRight = color ? SE : NE;
|
|
||||||
for (auto& p : ownPieces) {
|
|
||||||
if (!p.movePattern(king.pos())) continue;
|
|
||||||
auto to = positions[king.pos()];
|
|
||||||
auto from = p.coord();
|
|
||||||
|
|
||||||
if (p.type() == Piece::Knight) {
|
|
||||||
if (!m_inCheck) {
|
|
||||||
m_allowedInCheck = { p.pos() };
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
m_allowedInCheck.clear();
|
|
||||||
}
|
|
||||||
m_inCheck = true;
|
|
||||||
}
|
|
||||||
else if (p.type() == Piece::Pawn) {
|
|
||||||
Direction d {char(to[R] - from[R]), char(to[F] - from[F])};
|
|
||||||
if (d == pawnAttackLeft || d == pawnAttackRight) {
|
|
||||||
if (!m_inCheck) {
|
|
||||||
m_allowedInCheck = { p.pos() };
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
m_allowedInCheck.clear();
|
|
||||||
}
|
|
||||||
m_inCheck = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
Direction d = normalize({char(to[R] - from[R]), char(to[F] - from[F])});
|
|
||||||
std::set<char> tmp;
|
|
||||||
auto pos = traverse(p.pos(), d, Add(m_state->board, tmp, king.color()));
|
|
||||||
if (pos == king.pos()) {
|
|
||||||
tmp.insert(p.pos());
|
|
||||||
if (!m_inCheck) {
|
|
||||||
m_allowedInCheck = std::move(tmp);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
m_allowedInCheck.clear();
|
|
||||||
}
|
|
||||||
m_inCheck = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Executes the move of `piece` to `pos_to` if it is permitted.
// The target must be in the piece's allowed set; while in check, any piece
// other than the king must additionally move to a square listed in
// m_allowedInCheck. A piece of the other colour on the target is taken.
// On success the check state is reset and true is returned.
bool Chessboard::move(Piece& piece, char pos_to) {
    auto& allowed = piece.allowed();
    if (allowed.count(pos_to) == 0) return false;

    const bool mustAnswerCheck = m_inCheck && piece.type() != Piece::King;
    if (mustAnswerCheck && m_allowedInCheck.count(pos_to) == 0) return false;

    Piece* occupant = m_state->board[pos_to];
    if (occupant) {
        if (occupant->color() == piece.color()) return false;
        occupant->take();
    }

    m_state->board[piece.pos()] = nullptr;
    m_state->board[pos_to] = &piece;
    piece.setPos(pos_to);

    m_inCheck = false;
    m_allowedInCheck.clear();
    return true;
}
|
|
@ -1,33 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
#include <string>
|
|
||||||
#include <set>
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
// just basic validation
|
|
||||||
// fixme: missing en passant, castling, promotion, etc.
|
|
||||||
struct State;
|
|
||||||
class Piece;
|
|
||||||
class Chessboard {
|
|
||||||
public:
|
|
||||||
Chessboard();
|
|
||||||
~Chessboard();
|
|
||||||
std::string process(const std::string& command);
|
|
||||||
std::string stringifyBoard();
|
|
||||||
const std::string& grammar() { return m_grammar; }
|
|
||||||
const std::string& prompt() { return m_prompt; }
|
|
||||||
void setPrompt(const std::string& prompt);
|
|
||||||
private:
|
|
||||||
bool parseCommand(const std::string& command, Piece*& piece, char& pos_to);
|
|
||||||
bool move(Piece& piece, char pos);
|
|
||||||
void flagUpdates(char pos_from, char pos_to);
|
|
||||||
void updatePins(Piece& piece);
|
|
||||||
void detectChecks();
|
|
||||||
void setGrammar();
|
|
||||||
|
|
||||||
std::unique_ptr<State> m_state;
|
|
||||||
std::set<char> m_allowedInCheck;
|
|
||||||
bool m_inCheck = false;
|
|
||||||
int m_moveCounter = 0;
|
|
||||||
std::string m_grammar;
|
|
||||||
std::string m_prompt;
|
|
||||||
};
|
|
@ -1,193 +0,0 @@
|
|||||||
#include "WChess.h"
|
|
||||||
#include "Chessboard.h"
|
|
||||||
#include "grammar-parser.h"
|
|
||||||
#include "common.h"
|
|
||||||
#include <thread>
|
|
||||||
|
|
||||||
// Stores the whisper context, decoding parameters, callbacks and settings and
// creates a fresh chess board.
WChess::WChess(whisper_context * ctx,
               const whisper_full_params & wparams,
               callbacks cb,
               settings s)
    : m_ctx(ctx),
      m_wparams(wparams),
      m_cb(cb),
      m_settings(s),
      m_board(new Chessboard)
{
}
|
|
||||||
|
|
||||||
// Defaulted here rather than in the class body; the header only forward-declares
// Chessboard while holding it in a std::unique_ptr (presumably the reason for this placement).
WChess::~WChess() = default;
|
|
||||||
|
|
||||||
void WChess::set_move(const std::string& moves, float prob) const {
|
|
||||||
if (m_cb.set_move) (*m_cb.set_move)(moves, prob);
|
|
||||||
}
|
|
||||||
|
|
||||||
void WChess::set_grammar(const std::string& grammar) const {
|
|
||||||
if (m_cb.set_grammar) (*m_cb.set_grammar)(grammar);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool WChess::get_audio(std::vector<float>& pcmf32) const {
|
|
||||||
if (m_cb.get_audio) return (*m_cb.get_audio)(pcmf32);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string WChess::stringify_board() const {
|
|
||||||
return m_board->stringifyBoard();
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string WChess::get_grammar() const {
|
|
||||||
return m_board->grammar();
|
|
||||||
}
|
|
||||||
|
|
||||||
void WChess::run() {
|
|
||||||
bool have_prompt = true;
|
|
||||||
bool ask_prompt = !have_prompt;
|
|
||||||
|
|
||||||
float logprob_min = 0.0f;
|
|
||||||
|
|
||||||
float logprob_sum = 0.0f;
|
|
||||||
|
|
||||||
int n_tokens = 0;
|
|
||||||
|
|
||||||
std::vector<float> pcmf32_cur;
|
|
||||||
std::vector<float> pcmf32_prompt;
|
|
||||||
|
|
||||||
const std::string k_prompt = have_prompt ? "" : "rook to d4, f3";
|
|
||||||
int64_t t_ms = 0;
|
|
||||||
|
|
||||||
if (ask_prompt) {
|
|
||||||
fprintf(stdout, "\n");
|
|
||||||
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
|
|
||||||
fprintf(stdout, "\n");
|
|
||||||
|
|
||||||
ask_prompt = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
while (get_audio(pcmf32_cur)) {
|
|
||||||
if (!pcmf32_cur.empty()) {
|
|
||||||
// fprintf(stdout, "%s: Processing ...\n", __func__);
|
|
||||||
|
|
||||||
if (!have_prompt) {
|
|
||||||
const auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));
|
|
||||||
|
|
||||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
|
|
||||||
|
|
||||||
const float sim = similarity(txt, k_prompt);
|
|
||||||
|
|
||||||
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
|
|
||||||
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
|
|
||||||
ask_prompt = true;
|
|
||||||
} else {
|
|
||||||
fprintf(stdout, "\n");
|
|
||||||
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
|
|
||||||
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
|
|
||||||
fprintf(stdout, "\n");
|
|
||||||
|
|
||||||
// save the audio for the prompt
|
|
||||||
pcmf32_prompt = pcmf32_cur;
|
|
||||||
have_prompt = true;
|
|
||||||
m_board->setPrompt(k_prompt);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (!pcmf32_prompt.empty()) pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
|
||||||
constexpr size_t MIN_SIZE = 1.2 * WHISPER_SAMPLE_RATE;
|
|
||||||
if (MIN_SIZE > pcmf32_cur.size()) pcmf32_cur.insert(pcmf32_cur.begin(), MIN_SIZE - pcmf32_cur.size(), 0.0f);
|
|
||||||
|
|
||||||
// fprintf(stdout, "%s: grammar rules:\n'%s'\n", __func__, m_board->grammar().c_str());
|
|
||||||
|
|
||||||
auto grammar_parsed = grammar_parser::parse(m_board->grammar().c_str());
|
|
||||||
auto grammar_rules = grammar_parsed.c_rules();
|
|
||||||
|
|
||||||
m_wparams.grammar_rules = grammar_rules.data();
|
|
||||||
m_wparams.n_grammar_rules = grammar_rules.size();
|
|
||||||
|
|
||||||
m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("move");
|
|
||||||
auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));
|
|
||||||
|
|
||||||
const float p = 100.0f * std::exp(logprob_min);
|
|
||||||
|
|
||||||
fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
|
|
||||||
|
|
||||||
// find the prompt in the text
|
|
||||||
float best_sim = 0.0f;
|
|
||||||
size_t best_len = 0;
|
|
||||||
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
|
||||||
const auto prompt = txt.substr(0, n);
|
|
||||||
|
|
||||||
const float sim = similarity(prompt, k_prompt);
|
|
||||||
|
|
||||||
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
|
|
||||||
|
|
||||||
if (sim > best_sim) {
|
|
||||||
best_sim = sim;
|
|
||||||
best_len = n;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p);
|
|
||||||
std::string command = ::trim(txt.substr(best_len));
|
|
||||||
|
|
||||||
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
|
||||||
fprintf(stdout, "\n");
|
|
||||||
|
|
||||||
if (!command.empty()) {
|
|
||||||
set_move(m_board->process(command), p);
|
|
||||||
set_grammar(m_board->grammar());
|
|
||||||
}
|
|
||||||
if (m_board->grammar().empty()) {
|
|
||||||
fprintf(stdout, "%s: No more moves possible\n", __func__);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ask_prompt) {
|
|
||||||
fprintf(stdout, "\n");
|
|
||||||
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
|
|
||||||
fprintf(stdout, "\n");
|
|
||||||
|
|
||||||
ask_prompt = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string WChess::transcribe(
|
|
||||||
const std::vector<float> & pcmf32,
|
|
||||||
float & logprob_min,
|
|
||||||
float & logprob_sum,
|
|
||||||
int & n_tokens,
|
|
||||||
int64_t & t_ms) {
|
|
||||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
|
||||||
|
|
||||||
logprob_min = 0.0f;
|
|
||||||
logprob_sum = 0.0f;
|
|
||||||
n_tokens = 0;
|
|
||||||
t_ms = 0;
|
|
||||||
|
|
||||||
if (whisper_full(m_ctx, m_wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string result;
|
|
||||||
|
|
||||||
const int n_segments = whisper_full_n_segments(m_ctx);
|
|
||||||
for (int i = 0; i < n_segments; ++i) {
|
|
||||||
const char * text = whisper_full_get_segment_text(m_ctx, i);
|
|
||||||
|
|
||||||
result += text;
|
|
||||||
|
|
||||||
const int n = whisper_full_n_tokens(m_ctx, i);
|
|
||||||
for (int j = 0; j < n; ++j) {
|
|
||||||
const auto token = whisper_full_get_token_data(m_ctx, i, j);
|
|
||||||
|
|
||||||
if(token.plog > 0.0f) return {};
|
|
||||||
logprob_min = std::min(logprob_min, token.plog);
|
|
||||||
logprob_sum += token.plog;
|
|
||||||
++n_tokens;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
|
||||||
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
@ -1,63 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
#include "whisper.h"
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
class Chessboard;
|
|
||||||
|
|
||||||
class WChess {
|
|
||||||
public:
|
|
||||||
using CheckRunningCb = bool (*)();
|
|
||||||
using GetAudioCb = bool (*)(std::vector<float> &);
|
|
||||||
using SetMovesCb = void (*)(const std::string &, float);
|
|
||||||
using SetGrammarCb = void (*)(const std::string &);
|
|
||||||
using ClearAudioCb = void (*)();
|
|
||||||
|
|
||||||
struct callbacks {
|
|
||||||
GetAudioCb get_audio = nullptr;
|
|
||||||
SetMovesCb set_move = nullptr;
|
|
||||||
SetGrammarCb set_grammar = nullptr;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct settings {
|
|
||||||
int32_t vad_ms = 2000;
|
|
||||||
int32_t prompt_ms = 5000;
|
|
||||||
int32_t command_ms = 4000;
|
|
||||||
float vad_thold = 0.2f;
|
|
||||||
float freq_thold = 100.0f;
|
|
||||||
bool print_energy = false;
|
|
||||||
};
|
|
||||||
|
|
||||||
WChess(
|
|
||||||
whisper_context * ctx,
|
|
||||||
const whisper_full_params & wparams,
|
|
||||||
callbacks cb,
|
|
||||||
settings s
|
|
||||||
);
|
|
||||||
~WChess();
|
|
||||||
|
|
||||||
void run();
|
|
||||||
|
|
||||||
std::string stringify_board() const;
|
|
||||||
|
|
||||||
std::string get_grammar() const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
bool get_audio(std::vector<float>& pcmf32) const;
|
|
||||||
void set_move(const std::string& moves, float prob) const;
|
|
||||||
void set_grammar(const std::string& grammar) const;
|
|
||||||
|
|
||||||
std::string transcribe(
|
|
||||||
const std::vector<float> & pcmf32,
|
|
||||||
float & logprob_min,
|
|
||||||
float & logprob_sum,
|
|
||||||
int & n_tokens,
|
|
||||||
int64_t & t_ms);
|
|
||||||
|
|
||||||
whisper_context * m_ctx;
|
|
||||||
whisper_full_params m_wparams;
|
|
||||||
const callbacks m_cb;
|
|
||||||
const settings m_settings;
|
|
||||||
std::unique_ptr<Chessboard> m_board;
|
|
||||||
};
|
|
@ -1,117 +0,0 @@
|
|||||||
#include "Chessboard.h"
|
|
||||||
|
|
||||||
#define ASSERT(x) \
|
|
||||||
do { \
|
|
||||||
if (!(x)) { \
|
|
||||||
fprintf(stderr, "ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
|
|
||||||
fflush(stderr); \
|
|
||||||
exit(1); \
|
|
||||||
} \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
|
|
||||||
int main() {
|
|
||||||
{
|
|
||||||
Chessboard chess;
|
|
||||||
|
|
||||||
ASSERT(chess.process("pawn to d4") == "d2-d4");
|
|
||||||
ASSERT(chess.process("e5") == "e7-e5");
|
|
||||||
ASSERT(chess.process("c1 h6") == "c1-h6");
|
|
||||||
ASSERT(chess.process("queen h4") == "d8-h4");
|
|
||||||
ASSERT(chess.process("bishop to g5") == "h6-g5");
|
|
||||||
ASSERT(chess.process("bishop to b4") == "f8-b4");
|
|
||||||
ASSERT(chess.process("c4") == "");
|
|
||||||
ASSERT(chess.process("knight c3") == "b1-c3");
|
|
||||||
ASSERT(chess.process("knight c6") == "b8-c6");
|
|
||||||
ASSERT(chess.process("f3") == "");
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
Chessboard chess;
|
|
||||||
|
|
||||||
ASSERT(chess.process("d4") == "d2-d4");
|
|
||||||
ASSERT(chess.process("e5") == "e7-e5");
|
|
||||||
ASSERT(chess.process("e4") == "e2-e4");
|
|
||||||
ASSERT(chess.process("queen h4") == "d8-h4");
|
|
||||||
ASSERT(chess.process("queen h5") == "d1-h5");
|
|
||||||
ASSERT(chess.process("f5") == "");
|
|
||||||
ASSERT(chess.process("g6") == "g7-g6");
|
|
||||||
ASSERT(chess.process("knight e2") == "g1-e2");
|
|
||||||
ASSERT(chess.process("f5") == "f7-f5");
|
|
||||||
ASSERT(chess.process("knight g3") == "e2-g3");
|
|
||||||
ASSERT(chess.process("g5") == "");
|
|
||||||
ASSERT(chess.process("king e7") == "e8-e7");
|
|
||||||
ASSERT(chess.process("f4") == "f2-f4");
|
|
||||||
ASSERT(chess.process("g5") == "g6-g5");
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
Chessboard chess;
|
|
||||||
|
|
||||||
ASSERT(chess.process("e4") == "e2-e4");
|
|
||||||
ASSERT(chess.process("c5") == "c7-c5");
|
|
||||||
ASSERT(chess.process("e5") == "e4-e5");
|
|
||||||
ASSERT(chess.process("c4") == "c5-c4");
|
|
||||||
ASSERT(chess.process("e6") == "e5-e6");
|
|
||||||
ASSERT(chess.process("c3") == "c4-c3");
|
|
||||||
ASSERT(chess.process("e7") == "");
|
|
||||||
ASSERT(chess.process("f7") == "e6-f7");
|
|
||||||
ASSERT(chess.process("d2") == "");
|
|
||||||
ASSERT(chess.process("king to f7") == "e8-f7");
|
|
||||||
ASSERT(chess.process("f4") == "f2-f4");
|
|
||||||
ASSERT(chess.process("d2") == "c3-d2");
|
|
||||||
ASSERT(chess.process("f5") == "");
|
|
||||||
ASSERT(chess.process("king to e2") == "e1-e2");
|
|
||||||
ASSERT(chess.process("king to g6") == "f7-g6");
|
|
||||||
ASSERT(chess.process("f5") == "f4-f5");
|
|
||||||
ASSERT(chess.process("e6") == "");
|
|
||||||
ASSERT(chess.process("king to h5") == "g6-h5");
|
|
||||||
ASSERT(chess.process("g4") == "g2-g4");
|
|
||||||
ASSERT(chess.process("king to g5") == "h5-g5");
|
|
||||||
ASSERT(chess.process("h4") == "h2-h4");
|
|
||||||
ASSERT(chess.process("king to h5") == "");
|
|
||||||
ASSERT(chess.process("king to g6") == "");
|
|
||||||
ASSERT(chess.process("king to h6") == "g5-h6");
|
|
||||||
ASSERT(chess.process("bishop to d2") == "c1-d2");
|
|
||||||
ASSERT(chess.process("king to g5") == "");
|
|
||||||
ASSERT(chess.process("g5") == "g7-g5");
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
Chessboard chess;
|
|
||||||
ASSERT(chess.process("f4") == "f2-f4");
|
|
||||||
ASSERT(chess.process("e5") == "e7-e5");
|
|
||||||
ASSERT(chess.process("g4") == "g2-g4");
|
|
||||||
ASSERT(chess.process("queen to h4") == "d8-h4#");
|
|
||||||
ASSERT(chess.process("knight f3") == "");
|
|
||||||
ASSERT(chess.grammar().empty());
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
Chessboard chess;
|
|
||||||
ASSERT(chess.process("f4") == "f2-f4");
|
|
||||||
ASSERT(chess.process("e5") == "e7-e5");
|
|
||||||
ASSERT(chess.process("g4") == "g2-g4");
|
|
||||||
ASSERT(chess.process("d5") == "d7-d5");
|
|
||||||
ASSERT(chess.process("g1 f3") == "g1-f3");
|
|
||||||
ASSERT(chess.process("queen to h4") == "d8-h4");
|
|
||||||
ASSERT(!chess.grammar().empty());
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
Chessboard chess;
|
|
||||||
ASSERT(chess.process("knight c3") == "b1-c3");
|
|
||||||
ASSERT(chess.process("knight c6") == "b8-c6");
|
|
||||||
ASSERT(chess.process("knight b5") == "c3-b5");
|
|
||||||
ASSERT(chess.process("knight f6") == "g8-f6");
|
|
||||||
ASSERT(chess.process("knight d6") == "b5-d6");
|
|
||||||
ASSERT(chess.process("knight d4") == "");
|
|
||||||
ASSERT(chess.process("d6") == "c7-d6");
|
|
||||||
ASSERT(chess.process("e4") == "e2-e4");
|
|
||||||
ASSERT(chess.process("knight d4") == "c6-d4");
|
|
||||||
ASSERT(chess.process("d3") == "d2-d3");
|
|
||||||
ASSERT(chess.process("knight e4") == "f6-e4");
|
|
||||||
ASSERT(chess.process("king to e2") == "");
|
|
||||||
ASSERT(chess.process("king to d2") == "");
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,8 +0,0 @@
|
|||||||
# Command-line wchess executable; only configured when WHISPER_SDL2 is set.
if (WHISPER_SDL2)
    set(TARGET wchess)

    add_executable(${TARGET}
        wchess.cmd.cpp
    )

    include(DefaultTargetOptions)

    target_link_libraries(${TARGET} PRIVATE
        wchess-core
        common-sdl
        ${CMAKE_THREAD_LIBS_INIT}
    )
endif()
|
|
@ -1,247 +0,0 @@
|
|||||||
// Command line voice assisted chess
|
|
||||||
//
|
|
||||||
// Speak chess move commands to the microphone.
|
|
||||||
// The moves will be translated to chessboard positions.
|
|
||||||
//
|
|
||||||
//
|
|
||||||
|
|
||||||
#include "WChess.h"
|
|
||||||
#include "common-sdl.h"
|
|
||||||
#include <iostream>
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
#include <thread>
|
|
||||||
|
|
||||||
// command-line parameters
|
|
||||||
// command-line parameters (each field holds the default used when the
// corresponding option is not given)
struct whisper_params {
    // At most 4 threads, but never fewer than 1: hardware_concurrency() is
    // allowed to return 0 when the value cannot be determined.
    int32_t n_threads  = std::max(1, std::min(4, (int32_t) std::thread::hardware_concurrency()));
    int32_t prompt_ms  = 5000;
    int32_t command_ms = 8000;
    int32_t capture_id = -1;
    int32_t max_tokens = 32;
    int32_t audio_ctx  = 0;

    float vad_thold  = 0.6f;
    float freq_thold = 100.0f;

    float grammar_penalty = 100.0f;

    bool speed_up      = false;
    bool translate     = false;
    bool print_special = false;
    bool print_energy  = false;
    bool no_timestamps = true;
    bool use_gpu       = true;

    std::string language = "en";
    std::string model    = "models/ggml-base.en.bin";
    std::string fname_out;
    std::string commands;
    std::string prompt;
    std::string context;
    std::string grammar;
};
|
|
||||||
|
|
||||||
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
|
|
||||||
fprintf(stderr, "\n");
|
|
||||||
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
|
||||||
fprintf(stderr, "\n");
|
|
||||||
fprintf(stderr, "options:\n");
|
|
||||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
|
||||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
|
||||||
fprintf(stderr, " -pms N, --prompt-ms N [%-7d] prompt duration in milliseconds\n", params.prompt_ms);
|
|
||||||
fprintf(stderr, " -cms N, --command-ms N [%-7d] command duration in milliseconds\n", params.command_ms);
|
|
||||||
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
|
|
||||||
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
|
|
||||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
|
||||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
|
||||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
|
||||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
|
||||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
|
||||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
|
||||||
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
|
|
||||||
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
|
||||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
|
||||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
|
||||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
|
||||||
fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
|
|
||||||
fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str());
|
|
||||||
fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str());
|
|
||||||
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
|
|
||||||
fprintf(stderr, "\n");
|
|
||||||
}
|
|
||||||
|
|
||||||
bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|
||||||
for (int i = 1; i < argc; i++) {
|
|
||||||
std::string arg = argv[i];
|
|
||||||
|
|
||||||
if (arg == "-h" || arg == "--help") {
|
|
||||||
whisper_print_usage(argc, argv, params);
|
|
||||||
exit(0);
|
|
||||||
}
|
|
||||||
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
|
|
||||||
else if (arg == "-pms" || arg == "--prompt-ms") { params.prompt_ms = std::stoi(argv[++i]); }
|
|
||||||
else if (arg == "-cms" || arg == "--command-ms") { params.command_ms = std::stoi(argv[++i]); }
|
|
||||||
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
|
|
||||||
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
|
|
||||||
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
|
|
||||||
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
|
|
||||||
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
|
|
||||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
|
||||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
|
||||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
|
||||||
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
|
|
||||||
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
|
||||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
|
||||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
|
||||||
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
|
||||||
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
|
|
||||||
else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
|
|
||||||
else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
|
|
||||||
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
|
|
||||||
else {
|
|
||||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
|
||||||
whisper_print_usage(argc, argv, params);
|
|
||||||
exit(0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<WChess> g_wchess;
|
|
||||||
int g_moveCount = 0;
|
|
||||||
// Move callback handed to WChess. An empty `move` means the command was
// rejected; otherwise the move is counted and echoed. In both cases the
// current board and the side to move are printed to stdout.
void set_move(const std::string & move, float) {
    if (!move.empty()) {
        g_moveCount++;
        fprintf(stdout, "Move: %s\n\n", move.c_str());
    }
    else fprintf(stdout, "Move rejected\n\n");
    fprintf(stdout, "%s\n", g_wchess->stringify_board().c_str());
    // main() primes the counter with a synthetic "start" move, so the count is
    // odd whenever White is to move and even whenever Black is to move.
    fprintf(stdout, "%s\n", g_moveCount % 2 ? "White's turn" : "Black's turn");
}
|
|
||||||
|
|
||||||
audio_async g_audio(30*1000);
|
|
||||||
bool g_listening = false;
|
|
||||||
std::vector<float> g_pcmf32;
|
|
||||||
|
|
||||||
bool read_input() {
|
|
||||||
std::string input;
|
|
||||||
while (true) {
|
|
||||||
fprintf(stdout, "[(l)isten/(p)ause/(q)uit]: ");
|
|
||||||
std::cin >> input;
|
|
||||||
fprintf(stdout, "\n");
|
|
||||||
if (input[0] == 'q') {
|
|
||||||
fprintf(stdout, "Quitting\n");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (input[0] == 'l') {
|
|
||||||
if (!g_listening) {
|
|
||||||
fprintf(stdout, "Listening\n");
|
|
||||||
g_listening = true;
|
|
||||||
g_pcmf32.clear();
|
|
||||||
g_audio.resume();
|
|
||||||
g_audio.clear();
|
|
||||||
}
|
|
||||||
else fprintf(stdout, "Still listening\n");
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
if (g_listening) {
|
|
||||||
g_listening = false;
|
|
||||||
g_audio.get(0, g_pcmf32);
|
|
||||||
g_audio.pause();
|
|
||||||
fprintf(stdout, "Processing\n");
|
|
||||||
}
|
|
||||||
else fprintf(stdout, "Not listening\n");
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool get_audio(std::vector<float> & pcmf32_cur) {
|
|
||||||
if (!read_input()) return false;
|
|
||||||
if (!g_pcmf32.empty()) pcmf32_cur = std::move(g_pcmf32);
|
|
||||||
else pcmf32_cur.clear();
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
int main(int argc, char ** argv) {
|
|
||||||
whisper_params params;
|
|
||||||
|
|
||||||
if (whisper_params_parse(argc, argv, params) == false) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (whisper_lang_id(params.language.c_str()) == -1) {
|
|
||||||
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
|
||||||
whisper_print_usage(argc, argv, params);
|
|
||||||
exit(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
// whisper init
|
|
||||||
|
|
||||||
struct whisper_context_params cparams;
|
|
||||||
cparams.use_gpu = params.use_gpu;
|
|
||||||
|
|
||||||
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
|
||||||
if (!ctx) {
|
|
||||||
fprintf(stderr, "%s: whisper_init_from_file_with_params() failed!\n", __func__);
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// init audio
|
|
||||||
|
|
||||||
if (!g_audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
|
|
||||||
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
|
|
||||||
wparams.offset_ms = 0;
|
|
||||||
wparams.translate = false;
|
|
||||||
wparams.no_context = true;
|
|
||||||
wparams.single_segment = true;
|
|
||||||
wparams.print_realtime = false;
|
|
||||||
wparams.print_progress = false;
|
|
||||||
wparams.print_timestamps = true;
|
|
||||||
wparams.print_special = false;
|
|
||||||
wparams.no_timestamps = true;
|
|
||||||
|
|
||||||
wparams.max_tokens = 32;
|
|
||||||
wparams.audio_ctx = 768; // partial encoder context for better performance
|
|
||||||
|
|
||||||
wparams.temperature = 0.0f;
|
|
||||||
wparams.temperature_inc = 2.0f;
|
|
||||||
wparams.greedy.best_of = 1;
|
|
||||||
|
|
||||||
wparams.beam_search.beam_size = 1;
|
|
||||||
|
|
||||||
wparams.language = "en";
|
|
||||||
|
|
||||||
wparams.grammar_penalty = 100.0;
|
|
||||||
|
|
||||||
wparams.initial_prompt = params.context.data();
|
|
||||||
|
|
||||||
WChess::callbacks cb;
|
|
||||||
cb.get_audio = get_audio;
|
|
||||||
cb.set_move = set_move;
|
|
||||||
|
|
||||||
WChess::settings s;
|
|
||||||
s.vad_ms = 2000;
|
|
||||||
s.prompt_ms = params.prompt_ms;
|
|
||||||
s.command_ms = params.command_ms;
|
|
||||||
s.vad_thold = params.vad_thold;
|
|
||||||
s.freq_thold = params.freq_thold;
|
|
||||||
s.print_energy = params.print_energy;
|
|
||||||
|
|
||||||
g_wchess.reset(new WChess(ctx, wparams, cb, s));
|
|
||||||
set_move("start", 0);
|
|
||||||
g_wchess->run();
|
|
||||||
|
|
||||||
whisper_print_timings(ctx);
|
|
||||||
whisper_free(ctx);
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
@ -1,51 +0,0 @@
|
|||||||
# WebAssembly build of the wchess example.
set(TARGET wchess.wasm)

add_executable(${TARGET} wchess.wasm.cpp)

include(DefaultTargetOptions)

target_link_libraries(${TARGET} PRIVATE common wchess-core)

# Extra linker flags are only set in single-file mode.
unset(EXTRA_FLAGS)

if (WHISPER_WASM_SINGLE_FILE)
    set(EXTRA_FLAGS "-s SINGLE_FILE=1")
    message(STATUS "Embedding WASM inside chess.js")

    # Copy the generated JS next to the other page assets as js/chess.js.
    add_custom_command(
        TARGET ${TARGET} POST_BUILD
        COMMAND ${CMAKE_COMMAND} -E copy
            ${CMAKE_BINARY_DIR}/bin/${TARGET}.js
            ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/js/chess.js
    )
endif ()

set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
--bind \
-s USE_PTHREADS=1 \
-s PTHREAD_POOL_SIZE=8 \
-s INITIAL_MEMORY=1024MB \
-s TOTAL_MEMORY=1024MB \
-s FORCE_FILESYSTEM=1 \
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
${EXTRA_FLAGS} \
")

# Ship the bundled chessboard.js assets and jQuery with the build output.
add_custom_command(
    TARGET ${TARGET} POST_BUILD
    COMMAND ${CMAKE_COMMAND} -E copy_directory
        ${CMAKE_CURRENT_SOURCE_DIR}/chessboardjs-1.0.0
        ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/
    COMMAND ${CMAKE_COMMAND} -E copy
        ${CMAKE_CURRENT_SOURCE_DIR}/jquery-3.7.1.min.js
        ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/js/
)

# Generate the page and the shared helper script (@ONLY: only @VAR@ references are substituted).
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/index-tmpl.html ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/index.html @ONLY)
configure_file(${CMAKE_SOURCE_DIR}/examples/helpers.js ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/js/helpers.js @ONLY)
|
|
@ -1,54 +0,0 @@
|
|||||||
/*! chessboard.js v1.0.0 | (c) 2019 Chris Oakman | MIT License chessboardjs.com/license */
|
|
||||||
|
|
||||||
.clearfix-7da63 {
|
|
||||||
clear: both;
|
|
||||||
}
|
|
||||||
|
|
||||||
.board-b72b1 {
|
|
||||||
border: 2px solid #404040;
|
|
||||||
box-sizing: content-box;
|
|
||||||
}
|
|
||||||
|
|
||||||
.square-55d63 {
|
|
||||||
float: left;
|
|
||||||
position: relative;
|
|
||||||
|
|
||||||
/* disable any native browser highlighting */
|
|
||||||
-webkit-touch-callout: none;
|
|
||||||
-webkit-user-select: none;
|
|
||||||
-khtml-user-select: none;
|
|
||||||
-moz-user-select: none;
|
|
||||||
-ms-user-select: none;
|
|
||||||
user-select: none;
|
|
||||||
}
|
|
||||||
|
|
||||||
.white-1e1d7 {
|
|
||||||
background-color: #f0d9b5;
|
|
||||||
color: #b58863;
|
|
||||||
}
|
|
||||||
|
|
||||||
.black-3c85d {
|
|
||||||
background-color: #b58863;
|
|
||||||
color: #f0d9b5;
|
|
||||||
}
|
|
||||||
|
|
||||||
.highlight1-32417, .highlight2-9c5d2 {
|
|
||||||
box-shadow: inset 0 0 3px 3px yellow;
|
|
||||||
}
|
|
||||||
|
|
||||||
.notation-322f9 {
|
|
||||||
cursor: default;
|
|
||||||
font-family: "Helvetica Neue", Helvetica, Arial, sans-serif;
|
|
||||||
font-size: 14px;
|
|
||||||
position: absolute;
|
|
||||||
}
|
|
||||||
|
|
||||||
.alpha-d2270 {
|
|
||||||
bottom: 1px;
|
|
||||||
right: 3px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.numeric-fc462 {
|
|
||||||
top: 2px;
|
|
||||||
left: 2px;
|
|
||||||
}
|
|
@ -1,2 +0,0 @@
|
|||||||
/*! chessboard.js v1.0.0 | (c) 2019 Chris Oakman | MIT License chessboardjs.com/license */
|
|
||||||
.clearfix-7da63{clear:both}.board-b72b1{border:2px solid #404040;box-sizing:content-box}.square-55d63{float:left;position:relative;-webkit-touch-callout:none;-webkit-user-select:none;-khtml-user-select:none;-moz-user-select:none;-ms-user-select:none;user-select:none}.white-1e1d7{background-color:#f0d9b5;color:#b58863}.black-3c85d{background-color:#b58863;color:#f0d9b5}.highlight1-32417,.highlight2-9c5d2{box-shadow:inset 0 0 3px 3px #ff0}.notation-322f9{cursor:default;font-family:"Helvetica Neue",Helvetica,Arial,sans-serif;font-size:14px;position:absolute}.alpha-d2270{bottom:1px;right:3px}.numeric-fc462{top:2px;left:2px}
|
|
Before Width: | Height: | Size: 1.4 KiB |
Before Width: | Height: | Size: 2.9 KiB |
Before Width: | Height: | Size: 1.8 KiB |
Before Width: | Height: | Size: 777 B |
Before Width: | Height: | Size: 2.6 KiB |
Before Width: | Height: | Size: 748 B |
Before Width: | Height: | Size: 2.3 KiB |
Before Width: | Height: | Size: 2.8 KiB |
Before Width: | Height: | Size: 2.3 KiB |
Before Width: | Height: | Size: 1.5 KiB |
Before Width: | Height: | Size: 3.7 KiB |
Before Width: | Height: | Size: 1.1 KiB |
@ -1,32 +0,0 @@
|
|||||||
# chessboard.js Change Log
|
|
||||||
|
|
||||||
All notable changes to this project will be documented in this file.
|
|
||||||
|
|
||||||
## [1.0.0] - 2019-06-11
|
|
||||||
- Orientation methods now return current orientation. [Issue #64]
|
|
||||||
- Drop support for IE8
|
|
||||||
- Do not check for `window.JSON` (Error #1004)
|
|
||||||
- Rename `ChessBoard` to `Chessboard` (`ChessBoard` is still supported, however)
|
|
||||||
- id query selectors are now supported as the first argument to `Chessboard()`
|
|
||||||
- Remove Error #1002
|
|
||||||
- Format code according to [StandardJS]
|
|
||||||
- Bump minimum jQuery version to 1.8.3
|
|
||||||
- Throttle piece drag functions
|
|
||||||
|
|
||||||
## [0.3.0] - 2013-08-10
|
|
||||||
- Added `appearSpeed` animation config property
|
|
||||||
- Added `onSnapbackEnd` event
|
|
||||||
- Added `onMoveEnd` event
|
|
||||||
|
|
||||||
## [0.2.0] - 2013-08-05
|
|
||||||
- Added `onMouseoverSquare` and `onMouseoutSquare` events
|
|
||||||
- Added `onSnapEnd` event
|
|
||||||
- Added square code as CSS class on the squares
|
|
||||||
- Added [chess.js] integration examples
|
|
||||||
|
|
||||||
## [0.1.0] - 2013-05-21
|
|
||||||
- Initial release
|
|
||||||
|
|
||||||
[chess.js]:https://github.com/jhlywa/chess.js
|
|
||||||
[Issue #64]:https://github.com/oakmac/chessboardjs/issues/64
|
|
||||||
[StandardJS]:https://standardjs.com/
|
|
@ -1,20 +0,0 @@
|
|||||||
Copyright 2019 Chris Oakman
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining
|
|
||||||
a copy of this software and associated documentation files (the
|
|
||||||
"Software"), to deal in the Software without restriction, including
|
|
||||||
without limitation the rights to use, copy, modify, merge, publish,
|
|
||||||
distribute, sublicense, and/or sell copies of the Software, and to
|
|
||||||
permit persons to whom the Software is furnished to do so, subject to
|
|
||||||
the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be
|
|
||||||
included in all copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
|
||||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
||||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
|
||||||
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
|
||||||
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
|
||||||
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
|
||||||
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
@ -1,82 +0,0 @@
|
|||||||
# chessboard.js
|
|
||||||
|
|
||||||
chessboard.js is a JavaScript chessboard component. It depends on [jQuery].
|
|
||||||
|
|
||||||
Please see [chessboardjs.com] for documentation and examples.
|
|
||||||
|
|
||||||
## What is chessboard.js?
|
|
||||||
|
|
||||||
chessboard.js is a JavaScript chessboard component with a flexible "just a
|
|
||||||
board" API.
|
|
||||||
|
|
||||||
chessboard.js is a standalone JavaScript Chess Board. It is designed to be "just
|
|
||||||
a board" and expose a powerful API so that it can be used in different ways.
|
|
||||||
Here's a non-exhaustive list of things you can do with chessboard.js:
|
|
||||||
|
|
||||||
- Use chessboard.js to show game positions alongside your expert commentary.
|
|
||||||
- Use chessboard.js to have a tactics website where users have to guess the best
|
|
||||||
move.
|
|
||||||
- Integrate chessboard.js and [chess.js] with a PGN database and allow people to
|
|
||||||
search and playback games (see [Example 5000])
|
|
||||||
- Build a chess server and have users play their games out using the
|
|
||||||
chessboard.js board.
|
|
||||||
|
|
||||||
chessboard.js is flexible enough to handle any of these situations with relative
|
|
||||||
ease.
|
|
||||||
|
|
||||||
## What can chessboard.js **not** do?
|
|
||||||
|
|
||||||
The scope of chessboard.js is limited to "just a board." This is intentional and
|
|
||||||
makes chessboard.js flexible for handling a multitude of chess-related problems.
|
|
||||||
|
|
||||||
This is a common source of confusion for new users. [remove?]
|
|
||||||
|
|
||||||
Specifically, chessboard.js does not understand anything about how the game of
|
|
||||||
chess is played: how a knight moves, whose turn it is, whether White is in check, etc.
|
|
||||||
|
|
||||||
Fortunately, the powerful [chess.js] library deals with exactly this sort of
|
|
||||||
problem domain and plays nicely with chessboard.js's flexible API. Some examples
|
|
||||||
of chessboard.js combined with chess.js: 5000, 5001, 5002
|
|
||||||
|
|
||||||
Please see the powerful [chess.js] library for an API to deal with these sorts
|
|
||||||
of questions.
|
|
||||||
|
|
||||||
|
|
||||||
This logic is distinct from the logic of the board. Please see the powerful
|
|
||||||
[chess.js] library for this aspect of your application.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Here is a list of things that chessboard.js is **not**:
|
|
||||||
|
|
||||||
- A chess engine
|
|
||||||
- A legal move validator
|
|
||||||
- A PGN parser
|
|
||||||
|
|
||||||
chessboard.js is designed to work well with any of those things, but the idea
|
|
||||||
behind chessboard.js is that the logic that controls the board should be
|
|
||||||
independent of those other problems.
|
|
||||||
|
|
||||||
## Docs and Examples
|
|
||||||
|
|
||||||
- Docs - <http://chessboardjs.com/docs>
|
|
||||||
- Examples - <http://chessboardjs.com/examples>
|
|
||||||
|
|
||||||
## Developer Tools
|
|
||||||
|
|
||||||
```sh
|
|
||||||
# create a build in the build/ directory
|
|
||||||
npm run build
|
|
||||||
|
|
||||||
# re-build the website
|
|
||||||
npm run website
|
|
||||||
```
|
|
||||||
|
|
||||||
## License
|
|
||||||
|
|
||||||
[MIT License](LICENSE.md)
|
|
||||||
|
|
||||||
[jQuery]:https://jquery.com/
|
|
||||||
[chessboardjs.com]:http://chessboardjs.com
|
|
||||||
[chess.js]:https://github.com/jhlywa/chess.js
|
|
||||||
[Example 5000]:http://chessboardjs.com/examples#5000
|
|
@ -1,29 +0,0 @@
|
|||||||
{
|
|
||||||
"author": "Chris Oakman <chris@oakmac.com> (http://chrisoakman.com/)",
|
|
||||||
"name": "@chrisoakman/chessboardjs",
|
|
||||||
"description": "JavaScript chessboard widget",
|
|
||||||
"homepage": "https://chessboardjs.com",
|
|
||||||
"license": "MIT",
|
|
||||||
"version": "1.0.0",
|
|
||||||
"repository": {
|
|
||||||
"type": "git",
|
|
||||||
"url": "git://github.com/oakmac/chessboardjs.git"
|
|
||||||
},
|
|
||||||
"files": ["dist/"],
|
|
||||||
"dependencies": {
|
|
||||||
"jquery": ">=3.4.1"
|
|
||||||
},
|
|
||||||
"devDependencies": {
|
|
||||||
"csso": "3.5.1",
|
|
||||||
"fs-plus": "3.1.1",
|
|
||||||
"kidif": "1.1.0",
|
|
||||||
"mustache": "2.3.0",
|
|
||||||
"standard": "10.0.2",
|
|
||||||
"uglify-js": "3.6.0"
|
|
||||||
},
|
|
||||||
"scripts": {
|
|
||||||
"build": "standard lib/chessboard.js && node scripts/build.js",
|
|
||||||
"standard": "standard --fix lib/*.js website/js/*.js",
|
|
||||||
"website": "node scripts/website.js"
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,499 +0,0 @@
|
|||||||
<!doctype html>
|
|
||||||
<html lang="en-us">
|
|
||||||
<head>
|
|
||||||
<title>wchess : voice-controlled chess using Whisper + WebAssembly</title>
|
|
||||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script>
|
|
||||||
|
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=0.7, maximum-scale=1, minimum-scale=0.7, user-scalable=no"/>
|
|
||||||
<meta name="apple-mobile-web-app-capable" content="yes" />
|
|
||||||
|
|
||||||
<style>
|
|
||||||
#output {
|
|
||||||
width: 100%;
|
|
||||||
height: 100%;
|
|
||||||
margin: 0 auto;
|
|
||||||
margin-top: 10px;
|
|
||||||
border-left: 0px;
|
|
||||||
border-right: 0px;
|
|
||||||
padding-left: 0px;
|
|
||||||
padding-right: 0px;
|
|
||||||
display: block;
|
|
||||||
background-color: black;
|
|
||||||
color: white;
|
|
||||||
font-size: 10px;
|
|
||||||
font-family: 'Lucida Console', Monaco, monospace;
|
|
||||||
outline: none;
|
|
||||||
white-space: pre;
|
|
||||||
overflow-wrap: normal;
|
|
||||||
overflow-x: scroll;
|
|
||||||
}
|
|
||||||
.button {
|
|
||||||
background-color: #000000;
|
|
||||||
color: #FFFFFF;
|
|
||||||
padding: 20px;
|
|
||||||
border-radius: 10px;
|
|
||||||
-moz-border-radius: 10px;
|
|
||||||
-webkit-border-radius: 10px;
|
|
||||||
margin:10px;
|
|
||||||
width: 100px;
|
|
||||||
height: 50px;
|
|
||||||
-webkit-touch-callout: none; /* Safari */
|
|
||||||
-webkit-user-select: none; /* Chrome */
|
|
||||||
-moz-user-select: none; /* Firefox */
|
|
||||||
-ms-user-select: none; /* Internet Explorer/Edge */
|
|
||||||
user-select: none;
|
|
||||||
}
|
|
||||||
button[disabled]{
|
|
||||||
background-color: #cccccc;
|
|
||||||
color: #666666;
|
|
||||||
padding: 20px;
|
|
||||||
border-radius: 10px;
|
|
||||||
-moz-border-radius: 10px;
|
|
||||||
-webkit-border-radius: 10px;
|
|
||||||
margin:10px;
|
|
||||||
width: 100px;
|
|
||||||
}
|
|
||||||
.center {
|
|
||||||
display: flex;
|
|
||||||
justify-content: center;
|
|
||||||
align-items: center;
|
|
||||||
width: 500px;
|
|
||||||
}
|
|
||||||
#description {
|
|
||||||
width: 500px;
|
|
||||||
}
|
|
||||||
</style>
|
|
||||||
<link rel="stylesheet" href="css/chessboard-1.0.0.min.css" integrity="sha384-q94+BZtLrkL1/ohfjR8c6L+A6qzNH9R2hBLwyoAfu3i/WCvQjzL2RQJ3uNHDISdU" crossorigin="anonymous">
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<div id="main-container">
|
|
||||||
<div id="description">
|
|
||||||
<b>wchess : voice-controlled chess using Whisper + WebAssembly</b>
|
|
||||||
|
|
||||||
<br><br>
|
|
||||||
|
|
||||||
This is a demonstration of using Whisper to recognize voice commands in the browser.
|
|
||||||
|
|
||||||
<br><br>
|
|
||||||
|
|
||||||
Usage:<br>
|
|
||||||
|
|
||||||
<ul>
|
|
||||||
<li>Select a Whisper model</li>
|
|
||||||
<li>Accept the microphone permission request if prompted</li>
|
|
||||||
<li>Hold the button and say a chess move (e.g. "Knight to c3")</li>
|
|
||||||
<li>Release the button and wait for the move to be recognized</li>
|
|
||||||
<li>Repeat</li>
|
|
||||||
</ul>
|
|
||||||
|
|
||||||
Examples:<br>
|
|
||||||
|
|
||||||
<ul>
|
|
||||||
<li><b>"d4"</b></li>
|
|
||||||
<li><b>"e2 e4"</b></li>
|
|
||||||
<li><b>"Knight f3"</b></li>
|
|
||||||
<li><b>"Bishop to b5"</b></li>
|
|
||||||
</ul>
|
|
||||||
|
|
||||||
Features:<br>
|
|
||||||
|
|
||||||
<ul>
|
|
||||||
<li>Model quantization for reduced memory footprint (~42MB)</li>
|
|
||||||
<li><a href="https://github.com/ggerganov/whisper.cpp/pull/1229">Grammar-based sampling</a> for improved recognition accuracy</li>
|
|
||||||
</ul>
|
|
||||||
|
|
||||||
<b>
|
|
||||||
Note that not all chess moves are supported. For example, castling and pawn promotion
|
|
||||||
currently do not work, but can be easily implemented. There could also be some bugs in
|
|
||||||
the move handling logic in general. The main reason for that is to keep the implementation
|
|
||||||
simple. The assumption is that a real application would already have a proper move
|
|
||||||
validation logic in place.<br><br>
|
|
||||||
|
|
||||||
The main purpose of this example is to demonstrate the capabilities of whisper.cpp and
|
|
||||||
its application in the browser for voice recognition locally on your device.
|
|
||||||
</b>
|
|
||||||
|
|
||||||
<br><br>
|
|
||||||
|
|
||||||
You can find more about this project on <a href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/wchess">GitHub</a>.
|
|
||||||
|
|
||||||
<br><br>
|
|
||||||
|
|
||||||
<b>More examples:</b>
|
|
||||||
<a href="https://whisper.ggerganov.com/">main</a> |
|
|
||||||
<a href="https://whisper.ggerganov.com/bench">bench</a> |
|
|
||||||
<a href="https://whisper.ggerganov.com/stream">stream</a> |
|
|
||||||
<a href="https://whisper.ggerganov.com/command">command</a> |
|
|
||||||
<a href="https://whisper.ggerganov.com/talk">talk</a> |
|
|
||||||
|
|
||||||
<br><br>
|
|
||||||
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<hr>
|
|
||||||
|
|
||||||
<div id="model-whisper">
|
|
||||||
Whisper model: <span id="model-whisper-status"></span>
|
|
||||||
<button id="fetch-whisper-tiny-en" onclick="loadWhisper()">tiny.en (Q8_0, 42 MB)</button>
|
|
||||||
<span id="fetch-whisper-progress"></span>
|
|
||||||
<br><br>
|
|
||||||
<button id="clear" onclick="clearCache()">Clear browser cache</button>
|
|
||||||
<!--
|
|
||||||
<input type="file" id="file" name="file" onchange="loadFile(event, 'whisper.bin')" />
|
|
||||||
-->
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div id="game">
|
|
||||||
<br>
|
|
||||||
<div id="chessboard" style="width: 500px"></div>
|
|
||||||
<script src="js/jquery-3.7.1.min.js"></script>
|
|
||||||
<script src="js/chessboard-1.0.0.min.js"></script>
|
|
||||||
<script>
|
|
||||||
var board = Chessboard('chessboard', 'start')
|
|
||||||
var move_count = 0;
|
|
||||||
</script>
|
|
||||||
|
|
||||||
<br>
|
|
||||||
|
|
||||||
<div id="state">
|
|
||||||
Status: <b><span id="state-status">select model</span></b>
|
|
||||||
|
|
||||||
<div id="input" class="center">
|
|
||||||
<button id="toggler" class="button" onselectstart="return false" style="display: none">Hold</button>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<pre id="state-grammar">[The grammar will be displayed here]</pre>
|
|
||||||
|
|
||||||
<pre id="state-moves">[The moves will be displayed here]</pre>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<hr>
|
|
||||||
|
|
||||||
Debug output:
|
|
||||||
<textarea id="output" rows="20"></textarea>
|
|
||||||
|
|
||||||
<br>
|
|
||||||
|
|
||||||
<b>Troubleshooting</b>
|
|
||||||
|
|
||||||
<br><br>
|
|
||||||
|
|
||||||
The page does some heavy computations, so make sure:
|
|
||||||
|
|
||||||
<ul>
|
|
||||||
<li>To use a modern web browser (e.g. Chrome, Firefox)</li>
|
|
||||||
<li>Your browser supports WASM <a href="https://webassembly.org/roadmap/">Fixed-width SIMD</a></li>
|
|
||||||
</ul>
|
|
||||||
|
|
||||||
<div class="cell-version">
|
|
||||||
<span>
|
|
||||||
|
|
|
||||||
Build time: <span class="nav-link">@GIT_DATE@</span> |
|
|
||||||
Commit hash: <a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/commit/@GIT_SHA1@">@GIT_SHA1@</a> |
|
|
||||||
Commit subject: <span class="nav-link">@GIT_COMMIT_SUBJECT@</span> |
|
|
||||||
<a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/wchess">Source Code</a> |
|
|
||||||
</span>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<script type="text/javascript" src="js/helpers.js"></script>
|
|
||||||
<script type='text/javascript'>
|
|
||||||
// web audio context
|
|
||||||
var context = null;
|
|
||||||
|
|
||||||
// the command instance
|
|
||||||
var instance = null;
|
|
||||||
|
|
||||||
// model name
|
|
||||||
var model_whisper = null;
|
|
||||||
var model_file = null;
|
|
||||||
|
|
||||||
var module_ready = null;
|
|
||||||
|
|
||||||
// Emscripten module configuration object.
// All runtime output is routed to the debug textarea, and the whisper
// instance is created as soon as the runtime reports that it has finished
// starting up.
var Module = {
    // both output streams of the module are printed to the debug textarea
    print: printTextarea,
    printErr: printTextarea,

    // status messages are logged with a 'js: ' prefix
    setStatus(text) {
        printTextarea('js: ' + text);
    },

    // empty hook
    monitorRunDependencies(left) {},

    preRun() {
        printTextarea('js: Preparing ...');
    },

    // mark the module as ready and try to create the whisper instance
    // (initInstance() does nothing until the model file is available too)
    postRun() {
        printTextarea('js: Module initialized successfully!');
        module_ready = true;
        initInstance();
    },
};
|
|
||||||
|
|
||||||
// Create the whisper instance once both the WASM module and the model file
// are available. Safe to call repeatedly: it returns immediately while a
// prerequisite is missing or when an instance already exists.
function initInstance() {
    const canInit = module_ready && model_file && !instance;
    if (!canInit) {
        return;
    }

    instance = Module.init(model_file);

    // a falsy return value from Module.init() is treated as failure
    if (!instance) {
        printTextarea("js: failed to initialize whisper");
        return;
    }

    setStatus('Ready');
    printTextarea("js: whisper initialized, instance: " + instance);
}
|
|
||||||
|
|
||||||
// Show `text` in the status field of the page (element #state-status).
// The value is assigned through innerHTML, so it is interpreted as markup.
function setStatus(text) {
    const statusEl = document.getElementById('state-status');
    statusEl.innerHTML = text;
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// fetch models
|
|
||||||
//
|
|
||||||
|
|
||||||
let dbVersion = 1
|
|
||||||
let dbName = 'whisper.ggerganov.com';
|
|
||||||
let indexedDB = window.indexedDB || window.mozIndexedDB || window.webkitIndexedDB || window.msIndexedDB
|
|
||||||
|
|
||||||
// Write a downloaded model into the WASM file system and try to start whisper.
//
//   fname - file name to create in the root ("/") of the WASM file system
//   buf   - model data; its `length` is reported in the log
function storeFS(fname, buf) {
    // remove a previous copy of the file first; a failure of the unlink
    // (e.g. the file does not exist yet) is deliberately ignored
    try {
        Module.FS_unlink(fname);
    } catch (e) {
        // ignore
    }

    Module.FS_createDataFile("/", fname, buf, true, true);

    printTextarea(`storeFS: stored model: ${fname} size: ${buf.length}`);

    const statusEl = document.getElementById('model-whisper-status');
    statusEl.innerHTML = `loaded "${model_whisper}"!`;

    // remember the model file and create the instance if the module is ready
    model_file = fname;
    initInstance();
}
|
|
||||||
|
|
||||||
// Start loading the tiny.en (Q8_0) whisper model and prepare the UI.
//
// - hides the fetch button and shows the loading state
// - hands the download to loadRemote(), which reports progress through
//   cbProgress, delivers the data to storeFS and calls cbCancel on cancel
// - requests microphone access once so that the permission prompt appears now
// - reveals the "Hold" button
//
// The callbacks and the temporary stream are local to this function; they
// were previously assigned without a declaration, which created implicit
// global variables.
function loadWhisper() {
    setStatus('Loading')

    //let url = 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q8_0.bin';
    const url     = 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en-q8_0.bin';
    const dst     = 'whisper.bin';
    const size_mb = 42;

    model_whisper = 'tiny.en-q8_0';

    document.getElementById('model-whisper-status').innerHTML = 'loading "' + model_whisper + '" ... ';
    document.getElementById('fetch-whisper-tiny-en').style.display = 'none';

    // p is multiplied by 100 and shown as a percentage
    const cbProgress = function(p) {
        let el = document.getElementById('fetch-whisper-progress');
        el.innerHTML = Math.round(100*p) + '%';
    };

    // clear the "loading ..." text
    const cbCancel = function() {
        var el;
        el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
    };

    loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);

    // init audio capture so that the user receives a permission request
    {
        // NOTE(review): this context is created and closed again right away;
        // only the getUserMedia() call below uses the microphone
        let context = new AudioContext({
            sampleRate: 16000,
            channelCount: 1,
            echoCancellation: false,
            autoGainControl:  true,
            noiseSuppression: true,
        });
        navigator.mediaDevices.getUserMedia({audio: true, video: false})
            .then(function(stream) {
                // the stream is not needed here: stop all tracks immediately
                stream.getTracks().forEach(function(track) {
                    track.stop();
                });
            })
            .catch(function(err) {
                printTextarea('js: error getting audio stream: ' + err);
            });
        context.close();
    }

    document.getElementById('toggler').style.display = 'block';
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// microphone
|
|
||||||
//
|
|
||||||
|
|
||||||
const kSampleRate = 16000;
|
|
||||||
const kRestartRecording_s = 120;
|
|
||||||
const kIntervalAudio_ms = 250; // pass the recorded audio to the C++ instance at this rate
|
|
||||||
|
|
||||||
var mediaRecorder = null;
|
|
||||||
var doRecording = false;
|
|
||||||
var startTime = 0;
|
|
||||||
|
|
||||||
window.AudioContext = window.AudioContext || window.webkitAudioContext;
|
|
||||||
window.OfflineAudioContext = window.OfflineAudioContext || window.webkitOfflineAudioContext;
|
|
||||||
|
|
||||||
// Stop the active MediaRecorder, if there is one.
function stopRecording() {
    if (!mediaRecorder) {
        return;
    }
    mediaRecorder.stop();
}
|
|
||||||
|
|
||||||
// Record audio from the microphone and pass it to the whisper instance.
//
// Flow: getUserMedia -> MediaRecorder -> (on data) Blob -> FileReader ->
// decodeAudioData -> OfflineAudioContext rendering -> Module.set_audio().
//
// Changes compared to the previous version:
// - the AudioContext is closed before its reference is dropped, instead of
//   only being set to null (each recording used to leave an open context)
// - decodeAudioData() gets an error callback, so a decoding failure is
//   reported and the recorder state is reset instead of staying set forever
// - the context is captured in a local variable, so the decode step does not
//   depend on the global still being set when the data arrives
function startRecording() {
    if (!context) {
        context = new AudioContext({
            sampleRate: kSampleRate,
            channelCount: 1,
            echoCancellation: false,
            autoGainControl:  true,
            noiseSuppression: true,
        });
    }

    // context used by this recording
    const ctx = context;

    startTime = Date.now();

    var chunks = [];
    var stream = null;

    // reset the recorder state and release the audio context of this recording
    const releaseAudio = function() {
        mediaRecorder = null;
        if (context === ctx) {
            context = null;
        }
        if (ctx.state !== 'closed') {
            ctx.close();
        }
    };

    navigator.mediaDevices.getUserMedia({audio: true, video: false})
        .then(function(s) {
            stream = s;
            mediaRecorder = new MediaRecorder(stream);

            mediaRecorder.ondataavailable = function(e) {
                chunks.push(e.data);

                var blob = new Blob(chunks, { 'type' : 'audio/ogg; codecs=opus' });
                var reader = new FileReader();

                reader.onload = function(event) {
                    var buf = new Uint8Array(reader.result);

                    ctx.decodeAudioData(buf.buffer, function(audioBuffer) {
                        // render the decoded audio offline and take channel 0
                        var offlineContext = new OfflineAudioContext(audioBuffer.numberOfChannels, audioBuffer.length, audioBuffer.sampleRate);
                        var source = offlineContext.createBufferSource();
                        source.buffer = audioBuffer;
                        source.connect(offlineContext.destination);
                        source.start(0);

                        offlineContext.startRendering().then(function(renderedBuffer) {
                            let audio = renderedBuffer.getChannelData(0);
                            printTextarea('js: number of samples: ' + audio.length);
                            Module.set_audio(instance, audio);
                        });

                        releaseAudio();
                    }, function(err) {
                        printTextarea('js: error decoding audio: ' + err);
                        setStatus('Failed to decode audio');
                        releaseAudio();
                    });
                }

                reader.readAsArrayBuffer(blob);
            };

            // release the microphone when the recorder stops
            mediaRecorder.onstop = function(e) {
                stream.getTracks().forEach(function(track) {
                    track.stop();
                });
            };

            mediaRecorder.start();
        })
        .catch(function(err) {
            printTextarea('js: error getting audio stream: ' + err);
        });
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// main
|
|
||||||
//
|
|
||||||
|
|
||||||
var nLines = 0;
|
|
||||||
var movesAll = '';
|
|
||||||
|
|
||||||
// document.body.addEventListener('keydown', function(event) {
|
|
||||||
// if (event.keyCode === 32) {
|
|
||||||
// document.getElementById('toggler').innerText = "";
|
|
||||||
// onStart();
|
|
||||||
// }
|
|
||||||
// }, true);
|
|
||||||
|
|
||||||
// document.body.addEventListener('keyup', function(event) {
|
|
||||||
// if (event.keyCode === 32) {
|
|
||||||
// document.getElementById('toggler').innerText = "Hold";
|
|
||||||
// onStop();
|
|
||||||
// }
|
|
||||||
// }, true);
|
|
||||||
|
|
||||||
// Push-to-talk handling for the "Hold" button.
//
// Pressing the button clears its label and starts recording; releasing it
// restores the label and stops recording.
//
// The touch handlers call preventDefault(): after a touch interaction
// browsers also dispatch compatibility mouse events (mousedown/mouseup),
// which would otherwise run onStart()/onStop() a second time.
// A cancelled touch is treated like a release so the recording does not
// keep running.
{
    const toggler = document.getElementById('toggler');

    const press = function() {
        toggler.innerText = "";
        onStart();
    };

    const release = function() {
        toggler.innerText = "Hold";
        onStop();
    };

    toggler.addEventListener("touchstart", function(event) {
        event.preventDefault();
        press();
    }, true);

    toggler.addEventListener("touchend", function(event) {
        event.preventDefault();
        release();
    }, true);

    toggler.addEventListener("touchcancel", function(event) {
        release();
    }, true);

    toggler.addEventListener('mousedown', function(event) {
        press();
    }, true);

    toggler.addEventListener('mouseup', function(event) {
        release();
    }, true);
}
|
|
||||||
|
|
||||||
// Called when the "Hold" button is pressed: start listening.
// Does nothing while no whisper instance exists.
function onStart() {
    if (instance) {
        setStatus('Listening');
        startRecording();
    }
}
|
|
||||||
|
|
||||||
// Called when the "Hold" button is released: update the status, log the
// event and stop the recorder.
function onStop() {
    setStatus('Processing');
    printTextarea('js: stopping recording ...');

    stopRecording();
}
|
|
||||||
|
|
||||||
// Apply a recognized move to the board and update the move list and status.
//
//   move - move string; a trailing '#' marks the end of the game and is
//          removed before the move is passed to the board
//   prob - number, printed with two decimals in the move list
//
// A null move or a move of at most one character is reported as a failure
// and leaves the board, the move list and the move counter untouched.
function setMove(move, prob) {
    // whose turn it is according to the number of moves applied so far
    const turnLabel = function() {
        return move_count % 2 ? 'Black\'s turn' : 'White\'s turn';
    };

    if (move == null || !(move.length > 1)) {
        setStatus('Failed. ' + turnLabel());
        return;
    }

    const gameOver = move[move.length - 1] === '#';
    if (gameOver) {
        move = move.substring(0, move.length - 1);
        document.getElementById('toggler').disabled = true;
    }

    board.move(move);

    movesAll += move + ', prob = ' + prob.toFixed(2) + '% <br>';
    nLines++;

    // if more than 10 lines, remove the first line
    if (nLines > 10) {
        const cut = movesAll.indexOf('<br>');
        if (cut > 0) {
            movesAll = movesAll.substring(cut + 4);
            nLines--;
        }
    }

    ++move_count;
    setStatus(gameOver ? 'Done' : turnLabel());
    document.getElementById('state-moves').innerHTML = movesAll;
}
|
|
||||||
|
|
||||||
// Show `grammar` in the grammar field of the page (element #state-grammar).
// The value is assigned through innerHTML, so it is interpreted as markup.
function setGrammar(grammar) {
    const grammarEl = document.getElementById('state-grammar');
    grammarEl.innerHTML = grammar;
}
|
|
||||||
|
|
||||||
</script>
|
|
||||||
<script type="text/javascript" src="js/chess.js"></script>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
@ -1,141 +0,0 @@
|
|||||||
#include <WChess.h>
|
|
||||||
#include <emscripten.h>
|
|
||||||
#include <emscripten/bind.h>
|
|
||||||
|
|
||||||
#include <thread>
|
|
||||||
|
|
||||||
constexpr int N_THREAD = 8;
|
|
||||||
|
|
||||||
std::vector<struct whisper_context *> g_contexts(4, nullptr);
|
|
||||||
|
|
||||||
std::mutex g_mutex;
|
|
||||||
std::thread g_worker;
|
|
||||||
|
|
||||||
std::condition_variable g_cv;
|
|
||||||
|
|
||||||
bool g_running(false);
|
|
||||||
std::vector<float> g_pcmf32;
|
|
||||||
|
|
||||||
// Forward a move and its probability to the JavaScript function setMove(),
// executed on the main thread.
void set_move(const std::string & move, float prob) {
    const char * move_utf8 = move.c_str();

    MAIN_THREAD_EM_ASM({
        setMove(UTF8ToString($0), $1)
    }, move_utf8, prob);
}
|
|
||||||
|
|
||||||
// Forward the grammar text to the JavaScript function setGrammar(),
// executed on the main thread.
void set_grammar(const std::string & grammar) {
    const char * grammar_utf8 = grammar.c_str();

    MAIN_THREAD_EM_ASM({
        setGrammar(UTF8ToString($0))
    }, grammar_utf8);
}
|
|
||||||
|
|
||||||
bool get_audio(std::vector<float> & audio) {
|
|
||||||
std::unique_lock<std::mutex> lock(g_mutex);
|
|
||||||
g_cv.wait(lock, [] { return !g_running || !g_pcmf32.empty(); });
|
|
||||||
if (!g_running) return false;
|
|
||||||
audio = std::move(g_pcmf32);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void wchess_main(size_t i) {
|
|
||||||
struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
|
|
||||||
|
|
||||||
wparams.n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
|
|
||||||
wparams.offset_ms = 0;
|
|
||||||
wparams.translate = false;
|
|
||||||
wparams.no_context = true;
|
|
||||||
wparams.single_segment = true;
|
|
||||||
wparams.print_realtime = false;
|
|
||||||
wparams.print_progress = false;
|
|
||||||
wparams.print_timestamps = true;
|
|
||||||
wparams.print_special = false;
|
|
||||||
wparams.no_timestamps = true;
|
|
||||||
|
|
||||||
wparams.max_tokens = 32;
|
|
||||||
wparams.audio_ctx = 1280; // partial encoder context for better performance
|
|
||||||
|
|
||||||
wparams.temperature = 0.0f;
|
|
||||||
wparams.temperature_inc = 2.0f;
|
|
||||||
wparams.greedy.best_of = 1;
|
|
||||||
|
|
||||||
wparams.beam_search.beam_size = 1;
|
|
||||||
|
|
||||||
wparams.language = "en";
|
|
||||||
|
|
||||||
wparams.grammar_penalty = 100.0;
|
|
||||||
wparams.initial_prompt = "bishop to c3, rook to d4, knight to e5, d4 d5, knight to c3, c3, queen to d4, king b1, pawn to a1, bishop to b2, knight to c3,";
|
|
||||||
|
|
||||||
printf("command: using %d threads\n", wparams.n_threads);
|
|
||||||
|
|
||||||
WChess::callbacks cb;
|
|
||||||
cb.get_audio = get_audio;
|
|
||||||
cb.set_move = set_move;
|
|
||||||
cb.set_grammar = set_grammar;
|
|
||||||
|
|
||||||
WChess(g_contexts[i], wparams, cb, {}).run();
|
|
||||||
|
|
||||||
if (i < g_contexts.size()) {
|
|
||||||
whisper_free(g_contexts[i]);
|
|
||||||
g_contexts[i] = nullptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// JavaScript bindings: Module.init(), Module.free() and Module.set_audio().
EMSCRIPTEN_BINDINGS(command) {
    // init(path_model): load the model into the first free context slot and
    // start the worker thread for it.
    // Returns the 1-based slot index, or 0 when loading fails or no slot is free.
    emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
        for (size_t i = 0; i < g_contexts.size(); ++i) {
            if (g_contexts[i] == nullptr) {
                g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
                if (g_contexts[i] != nullptr) {
                    g_running = true;

                    // NOTE(review): there is a single worker thread for all
                    // slots. join() blocks until the previous worker has
                    // returned - confirm that init() is never called while a
                    // worker is still running.
                    if (g_worker.joinable()) {
                        g_worker.join();
                    }
                    g_worker = std::thread([i]() {
                        wchess_main(i);
                    });

                    return i + 1;
                } else {
                    return (size_t) 0;
                }
            }
        }

        return (size_t) 0;
    }));

    // free(index): ask the worker to stop. The index is not used; the worker
    // is woken up through the condition variable but not joined here.
    emscripten::function("free", emscripten::optional_override([](size_t /* index */) {
        {
            std::unique_lock<std::mutex> lock(g_mutex);
            g_running = false;
        }
        g_cv.notify_one();
    }));

    // set_audio(index, audio): copy the samples of the JavaScript typed array
    // `audio` into g_pcmf32 and wake up the worker.
    // `index` is the 1-based value returned by init().
    // Returns 0 on success, -1 for an index out of range, -2 for an empty slot.
    emscripten::function("set_audio", emscripten::optional_override([](size_t index, const emscripten::val & audio) {
        // convert to a 0-based slot; an index of 0 wraps around and is
        // rejected by the range check below
        --index;

        if (index >= g_contexts.size()) {
            return -1;
        }

        if (g_contexts[index] == nullptr) {
            return -2;
        }

        {
            std::lock_guard<std::mutex> lock(g_mutex);
            const int n = audio["length"].as<int>();

            // Size the destination before looking up the heap buffer: if the
            // allocation grows the WASM memory (builds with memory growth
            // enabled), a buffer object fetched earlier would be stale.
            g_pcmf32.resize(n);

            emscripten::val heap = emscripten::val::module_property("HEAPU8");
            emscripten::val memory = heap["buffer"];

            // view of the same type as `audio`, placed over the storage of
            // g_pcmf32 inside the WASM heap; set() copies the samples into it
            emscripten::val memoryView = audio["constructor"].new_(memory, reinterpret_cast<uintptr_t>(g_pcmf32.data()), n);
            memoryView.call<void>("set", audio);
        }
        g_cv.notify_one();

        return 0;
    }));
}
|
|
@ -206,7 +206,6 @@ void AudioInputCallback(void * inUserData,
|
|||||||
params.offset_ms = 0;
|
params.offset_ms = 0;
|
||||||
params.no_context = true;
|
params.no_context = true;
|
||||||
params.single_segment = self->stateInp.isRealtime;
|
params.single_segment = self->stateInp.isRealtime;
|
||||||
params.no_timestamps = params.single_segment;
|
|
||||||
|
|
||||||
CFTimeInterval startTime = CACurrentMediaTime();
|
CFTimeInterval startTime = CACurrentMediaTime();
|
||||||
|
|
||||||
|
@ -61,9 +61,7 @@ models = [
|
|||||||
"ggml-small.bin",
|
"ggml-small.bin",
|
||||||
"ggml-medium.en.bin",
|
"ggml-medium.en.bin",
|
||||||
"ggml-medium.bin",
|
"ggml-medium.bin",
|
||||||
"ggml-large-v1.bin",
|
"ggml-large.bin",
|
||||||
"ggml-large-v2.bin",
|
|
||||||
"ggml-large-v3.bin",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,154 +0,0 @@
|
|||||||
#!/bin/bash
#
# Synchronize ggml changes to whisper.cpp
#
# Usage:
#
#   $ cd /path/to/whisper.cpp
#   $ ./extra/sync-ggml-am.sh
#
# The script collects the ggml commits made since the commit recorded in
# extra/sync-ggml.last, exports them as patches restricted to the shared
# files, rewrites PR references and file paths to the whisper.cpp layout and
# applies the result with `git am`.
#

set -e

sd=$(dirname "$0")
cd "$sd/../"

SRC_WHISPER=$(pwd)

# Resolve the sibling ggml checkout. The `cd` and the `pwd` are chained with
# `&&`: with `cd ../ggml; pwd` a failed `cd` was masked and `pwd` returned the
# whisper.cpp directory, so the "not found" check could never trigger.
if ! SRC_GGML=$(cd ../ggml 2>/dev/null && pwd); then
    echo "ggml not found at $SRC_WHISPER/../ggml"
    exit 1
fi

lc=$(cat "$SRC_WHISPER/extra/sync-ggml.last")
echo "Syncing ggml changes since commit $lc"

cd "$SRC_GGML"

git log --oneline "$lc"..HEAD

# oldest first; commits whose subject contains "(whisper/NNN)" are skipped
git log --oneline "$lc"..HEAD --reverse | grep -v "(whisper/[0-9]*)" | cut -d' ' -f1 > "$SRC_WHISPER/ggml-commits"

if [ ! -s "$SRC_WHISPER/ggml-commits" ]; then
    rm -v "$SRC_WHISPER/ggml-commits"
    echo "No new commits"
    exit 0
fi

# start from a clean patch file
if [ -f "$SRC_WHISPER/ggml-src.patch" ]; then
    rm -v "$SRC_WHISPER/ggml-src.patch"
fi

# export every commit, limited to the files that are shared with whisper.cpp
while read -r c; do
    git format-patch -k "$c"~1.."$c" --stdout -- \
        include/ggml/ggml*.h \
        src/ggml*.h \
        src/ggml*.c \
        src/ggml*.cpp \
        src/ggml*.m \
        src/ggml*.metal \
        src/ggml*.cu \
        examples/common.h \
        examples/common.cpp \
        examples/common-ggml.h \
        examples/common-ggml.cpp \
        examples/whisper/whisper.h \
        examples/whisper/whisper.cpp \
        examples/whisper/main.cpp \
        examples/whisper/quantize.cpp \
        >> "$SRC_WHISPER/ggml-src.patch"
done < "$SRC_WHISPER/ggml-commits"

rm -v "$SRC_WHISPER/ggml-commits"

# delete files if empty
if [ ! -s "$SRC_WHISPER/ggml-src.patch" ]; then
    rm -v "$SRC_WHISPER/ggml-src.patch"
fi

cd "$SRC_WHISPER"

if [ -f "$SRC_WHISPER/ggml-src.patch" ]; then
    # replace PR numbers
    #
    # Subject: some text (#1234)
    # Subject: some text (ggml/1234)
    sed -e 's/^Subject: \(.*\) (#\([0-9]*\))/Subject: \1 (ggml\/\2)/' ggml-src.patch > ggml-src.patch.tmp
    mv ggml-src.patch.tmp ggml-src.patch

    sed -e 's/^\(.*\) (#\([0-9]*\))$/\1 (ggml\/\2)/' ggml-src.patch > ggml-src.patch.tmp
    mv ggml-src.patch.tmp ggml-src.patch

    # replace filenames:
    #
    # src/ggml.c              -> ggml.c
    # src/ggml-alloc.c        -> ggml-alloc.c
    # src/ggml-backend-impl.h -> ggml-backend-impl.h
    # src/ggml-backend.c      -> ggml-backend.c
    # src/ggml-cuda.cu        -> ggml-cuda.cu
    # src/ggml-cuda.h         -> ggml-cuda.h
    # src/ggml-impl.h         -> ggml-impl.h
    # src/ggml-metal.h        -> ggml-metal.h
    # src/ggml-metal.m        -> ggml-metal.m
    # src/ggml-mpi.h          -> ggml-mpi.h
    # src/ggml-mpi.c          -> ggml-mpi.c
    # src/ggml-opencl.cpp     -> ggml-opencl.cpp
    # src/ggml-opencl.h       -> ggml-opencl.h
    # src/ggml-quants.c       -> ggml-quants.c
    # src/ggml-quants.h       -> ggml-quants.h
    # include/ggml/ggml.h         -> ggml.h
    # include/ggml/ggml-alloc.h   -> ggml-alloc.h
    # include/ggml/ggml-backend.h -> ggml-backend.h
    #
    # examples/common.h        -> examples/common.h
    # examples/common.cpp      -> examples/common.cpp
    # examples/common-ggml.h   -> examples/common-ggml.h
    # examples/common-ggml.cpp -> examples/common-ggml.cpp
    #
    # examples/whisper/whisper.h    -> whisper.h
    # examples/whisper/whisper.cpp  -> whisper.cpp
    # examples/whisper/main.cpp     -> examples/main/main.cpp
    # examples/whisper/quantize.cpp -> examples/quantize/quantize.cpp

    sed \
        -e 's/src\/ggml\.c/ggml.c/g' \
        -e 's/src\/ggml-alloc\.c/ggml-alloc.c/g' \
        -e 's/src\/ggml-backend-impl\.h/ggml-backend-impl.h/g' \
        -e 's/src\/ggml-backend\.c/ggml-backend.c/g' \
        -e 's/src\/ggml-cuda\.cu/ggml-cuda.cu/g' \
        -e 's/src\/ggml-cuda\.h/ggml-cuda.h/g' \
        -e 's/src\/ggml-impl\.h/ggml-impl.h/g' \
        -e 's/src\/ggml-metal\.h/ggml-metal.h/g' \
        -e 's/src\/ggml-metal\.m/ggml-metal.m/g' \
        -e 's/src\/ggml-mpi\.h/ggml-mpi.h/g' \
        -e 's/src\/ggml-mpi\.c/ggml-mpi.c/g' \
        -e 's/src\/ggml-opencl\.cpp/ggml-opencl.cpp/g' \
        -e 's/src\/ggml-opencl\.h/ggml-opencl.h/g' \
        -e 's/src\/ggml-quants\.c/ggml-quants.c/g' \
        -e 's/src\/ggml-quants\.h/ggml-quants.h/g' \
        -e 's/include\/ggml\/ggml\.h/ggml.h/g' \
        -e 's/include\/ggml\/ggml-alloc\.h/ggml-alloc.h/g' \
        -e 's/include\/ggml\/ggml-backend\.h/ggml-backend.h/g' \
        -e 's/examples\/common\.h/examples\/common.h/g' \
        -e 's/examples\/common\.cpp/examples\/common.cpp/g' \
        -e 's/examples\/common-ggml\.h/examples\/common-ggml.h/g' \
        -e 's/examples\/common-ggml\.cpp/examples\/common-ggml.cpp/g' \
        -e 's/examples\/whisper\/whisper\.h/whisper.h/g' \
        -e 's/examples\/whisper\/whisper\.cpp/whisper.cpp/g' \
        -e 's/examples\/whisper\/main\.cpp/examples\/main\/main.cpp/g' \
        -e 's/examples\/whisper\/quantize\.cpp/examples\/quantize\/quantize.cpp/g' \
        ggml-src.patch > ggml-src.patch.tmp
    mv ggml-src.patch.tmp ggml-src.patch

    # with `set -e` a failing `git am` stops the script here and leaves
    # ggml-src.patch in place
    git am ggml-src.patch

    rm -v "$SRC_WHISPER/ggml-src.patch"
fi

# update last commit
cd "$SRC_GGML"
git log -1 --format=%H > "$SRC_WHISPER/extra/sync-ggml.last"

echo "Done"

exit 0
|
|
@ -1 +0,0 @@
|
|||||||
3fd01e00e40583ccd4b393a7c6502d6a4455a1d5
|
|
@ -1,5 +0,0 @@
|
|||||||
#!/bin/bash

# Copy the llama.cpp sources used by the talk-llama example from the sibling
# llama.cpp checkout (paths are relative to the current directory).
for f in llama.h llama.cpp unicode.h; do
    cp -rpv "../llama.cpp/$f" "./examples/talk-llama/$f"
done
|
|
67
ggml-alloc.c
@ -72,7 +72,7 @@ static void remove_allocated_tensor(ggml_tallocr_t alloc, struct ggml_tensor * t
|
|||||||
|
|
||||||
// check if a tensor is allocated by this buffer
|
// check if a tensor is allocated by this buffer
|
||||||
static bool ggml_tallocr_is_own(ggml_tallocr_t alloc, const struct ggml_tensor * tensor) {
|
static bool ggml_tallocr_is_own(ggml_tallocr_t alloc, const struct ggml_tensor * tensor) {
|
||||||
return tensor->buffer == alloc->buffer && (!tensor->view_src || tensor->view_src->buffer == alloc->buffer);
|
return tensor->buffer == alloc->buffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool ggml_is_view(struct ggml_tensor * t) {
|
static bool ggml_is_view(struct ggml_tensor * t) {
|
||||||
@ -137,7 +137,7 @@ void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
|
|||||||
|
|
||||||
#ifdef GGML_ALLOCATOR_DEBUG
|
#ifdef GGML_ALLOCATOR_DEBUG
|
||||||
add_allocated_tensor(alloc, tensor);
|
add_allocated_tensor(alloc, tensor);
|
||||||
size_t cur_max = (char*)addr - (char*)alloc->base + size;
|
size_t cur_max = (char*)addr - (char*)alloc->data + size;
|
||||||
if (cur_max > alloc->max_size) {
|
if (cur_max > alloc->max_size) {
|
||||||
printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
|
printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
|
||||||
for (int i = 0; i < 1024; i++) {
|
for (int i = 0; i < 1024; i++) {
|
||||||
@ -168,6 +168,10 @@ static void ggml_tallocr_free_tensor(ggml_tallocr_t alloc, struct ggml_tensor *
|
|||||||
size = aligned_offset(NULL, size, alloc->alignment);
|
size = aligned_offset(NULL, size, alloc->alignment);
|
||||||
AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks);
|
AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks);
|
||||||
|
|
||||||
|
if (!alloc->measure) {
|
||||||
|
ggml_backend_buffer_free_tensor(alloc->buffer, tensor);
|
||||||
|
}
|
||||||
|
|
||||||
#ifdef GGML_ALLOCATOR_DEBUG
|
#ifdef GGML_ALLOCATOR_DEBUG
|
||||||
remove_allocated_tensor(alloc, tensor);
|
remove_allocated_tensor(alloc, tensor);
|
||||||
#endif
|
#endif
|
||||||
@ -233,7 +237,7 @@ void ggml_tallocr_reset(ggml_tallocr_t alloc) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment) {
|
ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment) {
|
||||||
struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(data, size);
|
struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(NULL, data, size);
|
||||||
|
|
||||||
ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr));
|
ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr));
|
||||||
|
|
||||||
@ -445,15 +449,17 @@ static ggml_tallocr_t node_tallocr(ggml_gallocr_t galloc, struct ggml_tensor * n
|
|||||||
static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool update_backend) {
|
static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool update_backend) {
|
||||||
ggml_tallocr_t alloc = node_tallocr(galloc, view);
|
ggml_tallocr_t alloc = node_tallocr(galloc, view);
|
||||||
|
|
||||||
|
//printf("init_view: %s from src %s\n", view->name, view->view_src->name);
|
||||||
GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL);
|
GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL);
|
||||||
if (update_backend) {
|
if (update_backend) {
|
||||||
view->backend = view->view_src->backend;
|
view->backend = view->view_src->backend;
|
||||||
}
|
}
|
||||||
// views are initialized in the alloc buffer rather than the view_src buffer
|
view->buffer = view->view_src->buffer;
|
||||||
view->buffer = alloc->buffer;
|
|
||||||
view->data = (char *)view->view_src->data + view->view_offs;
|
view->data = (char *)view->view_src->data + view->view_offs;
|
||||||
|
|
||||||
assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->buft == alloc->buffer->buft);
|
// FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend
|
||||||
|
// due to the ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras
|
||||||
|
assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->backend == alloc->buffer->backend);
|
||||||
|
|
||||||
if (!alloc->measure) {
|
if (!alloc->measure) {
|
||||||
ggml_backend_buffer_init_tensor(alloc->buffer, view);
|
ggml_backend_buffer_init_tensor(alloc->buffer, view);
|
||||||
@ -735,10 +741,6 @@ void ggml_allocr_set_parse_seq(ggml_allocr_t alloc, const int * list, int n) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void ggml_allocr_free(ggml_allocr_t alloc) {
|
void ggml_allocr_free(ggml_allocr_t alloc) {
|
||||||
if (alloc == NULL) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_gallocr_free(alloc->galloc);
|
ggml_gallocr_free(alloc->galloc);
|
||||||
ggml_tallocr_free(alloc->talloc);
|
ggml_tallocr_free(alloc->talloc);
|
||||||
free(alloc);
|
free(alloc);
|
||||||
@ -763,48 +765,3 @@ size_t ggml_allocr_max_size(ggml_allocr_t alloc) {
|
|||||||
size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph) {
|
size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph) {
|
||||||
return ggml_gallocr_alloc_graph(alloc->galloc, alloc->talloc, graph);
|
return ggml_gallocr_alloc_graph(alloc->galloc, alloc->talloc, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
// utils
|
|
||||||
ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
|
|
||||||
GGML_ASSERT(ggml_get_no_alloc(ctx) == true);
|
|
||||||
|
|
||||||
size_t alignment = ggml_backend_buft_get_alignment(buft);
|
|
||||||
|
|
||||||
size_t nbytes = 0;
|
|
||||||
for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
|
||||||
if (t->data == NULL && t->view_src == NULL) {
|
|
||||||
nbytes += GGML_PAD(ggml_backend_buft_get_alloc_size(buft, t), alignment);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (nbytes == 0) {
|
|
||||||
// all the tensors in the context are already allocated
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, nbytes);
|
|
||||||
ggml_tallocr_t tallocr = ggml_tallocr_new_from_buffer(buffer);
|
|
||||||
|
|
||||||
for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
|
||||||
if (t->data == NULL) {
|
|
||||||
if (t->view_src == NULL) {
|
|
||||||
ggml_tallocr_alloc(tallocr, t);
|
|
||||||
} else {
|
|
||||||
ggml_backend_view_init(buffer, t);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (t->view_src != NULL) {
|
|
||||||
// view of a pre-allocated tensor
|
|
||||||
ggml_backend_view_init(buffer, t);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_tallocr_free(tallocr);
|
|
||||||
|
|
||||||
return buffer;
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend) {
|
|
||||||
return ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_get_default_buffer_type(backend));
|
|
||||||
}
|
|
||||||
|
@ -8,7 +8,6 @@ extern "C" {
|
|||||||
|
|
||||||
struct ggml_backend;
|
struct ggml_backend;
|
||||||
struct ggml_backend_buffer;
|
struct ggml_backend_buffer;
|
||||||
struct ggml_backend_buffer_type;
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// Legacy API
|
// Legacy API
|
||||||
@ -43,7 +42,7 @@ GGML_API size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph
|
|||||||
// ggml-backend v2 API
|
// ggml-backend v2 API
|
||||||
//
|
//
|
||||||
|
|
||||||
// Separate tensor and graph allocator objects
|
// Seperate tensor and graph allocator objects
|
||||||
// This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators
|
// This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators
|
||||||
// The original API is kept as a wrapper around the new API
|
// The original API is kept as a wrapper around the new API
|
||||||
|
|
||||||
@ -81,12 +80,6 @@ GGML_API void ggml_gallocr_alloc_graph_n(
|
|||||||
struct ggml_hash_set hash_set,
|
struct ggml_hash_set hash_set,
|
||||||
ggml_tallocr_t * hash_node_talloc);
|
ggml_tallocr_t * hash_node_talloc);
|
||||||
|
|
||||||
|
|
||||||
// Utils
|
|
||||||
// Create a buffer and allocate all the tensors in a ggml_context
|
|
||||||
GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, struct ggml_backend_buffer_type * buft);
|
|
||||||
GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, struct ggml_backend * backend);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
@ -12,54 +12,31 @@ extern "C" {
|
|||||||
// Backend buffer
|
// Backend buffer
|
||||||
//
|
//
|
||||||
|
|
||||||
// buffer type
|
|
||||||
typedef void * ggml_backend_buffer_type_context_t;
|
|
||||||
|
|
||||||
struct ggml_backend_buffer_type_i {
|
|
||||||
ggml_backend_buffer_t (*alloc_buffer) (ggml_backend_buffer_type_t buft, size_t size);
|
|
||||||
size_t (*get_alignment) (ggml_backend_buffer_type_t buft); // tensor alignment
|
|
||||||
size_t (*get_alloc_size) (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor); // data size needed to allocate the tensor, including padding
|
|
||||||
bool (*supports_backend)(ggml_backend_buffer_type_t buft, ggml_backend_t backend); // check if the buffer type is usable by the backend
|
|
||||||
// check if tensor data is in host memory
|
|
||||||
// should be equivalent to supports_backend(buft, ggml_backend_cpu_init())
|
|
||||||
bool (*is_host) (ggml_backend_buffer_type_t buft);
|
|
||||||
};
|
|
||||||
|
|
||||||
struct ggml_backend_buffer_type {
|
|
||||||
struct ggml_backend_buffer_type_i iface;
|
|
||||||
ggml_backend_buffer_type_context_t context;
|
|
||||||
};
|
|
||||||
|
|
||||||
// buffer
|
|
||||||
typedef void * ggml_backend_buffer_context_t;
|
typedef void * ggml_backend_buffer_context_t;
|
||||||
|
|
||||||
struct ggml_backend_buffer_i {
|
struct ggml_backend_buffer_i {
|
||||||
void (*free_buffer) (ggml_backend_buffer_t buffer);
|
void (*free_buffer) (ggml_backend_buffer_t buffer);
|
||||||
//void (*reset) (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
|
void * (*get_base) (ggml_backend_buffer_t buffer); // get base pointer
|
||||||
void * (*get_base) (ggml_backend_buffer_t buffer);
|
size_t (*get_alloc_size)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-allocation callback
|
||||||
void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
|
void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // post-allocation callback
|
||||||
void (*set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
void (*free_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-free callback
|
||||||
void (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
|
||||||
// (optional) copy tensor between different buffer-type, allow for single-copy tranfers
|
|
||||||
void (*cpy_tensor_from)(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
|
|
||||||
void (*cpy_tensor_to) (ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
|
|
||||||
void (*clear) (ggml_backend_buffer_t buffer, uint8_t value);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ggml_backend_buffer {
|
struct ggml_backend_buffer {
|
||||||
struct ggml_backend_buffer_i iface;
|
struct ggml_backend_buffer_i iface;
|
||||||
ggml_backend_buffer_type_t buft;
|
|
||||||
|
ggml_backend_t backend;
|
||||||
ggml_backend_buffer_context_t context;
|
ggml_backend_buffer_context_t context;
|
||||||
|
|
||||||
size_t size;
|
size_t size;
|
||||||
};
|
};
|
||||||
|
|
||||||
ggml_backend_buffer_t ggml_backend_buffer_init(
|
GGML_API ggml_backend_buffer_t ggml_backend_buffer_init(
|
||||||
ggml_backend_buffer_type_t buft,
|
struct ggml_backend * backend,
|
||||||
struct ggml_backend_buffer_i iface,
|
struct ggml_backend_buffer_i iface,
|
||||||
ggml_backend_buffer_context_t context,
|
ggml_backend_buffer_context_t context,
|
||||||
size_t size);
|
size_t size);
|
||||||
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// Backend
|
// Backend
|
||||||
//
|
//
|
||||||
@ -72,25 +49,28 @@ extern "C" {
|
|||||||
void (*free)(ggml_backend_t backend);
|
void (*free)(ggml_backend_t backend);
|
||||||
|
|
||||||
// buffer allocation
|
// buffer allocation
|
||||||
ggml_backend_buffer_type_t (*get_default_buffer_type)(ggml_backend_t backend);
|
ggml_backend_buffer_t (*alloc_buffer)(ggml_backend_t backend, size_t size);
|
||||||
|
|
||||||
// (optional) asynchroneous tensor data access
|
// get buffer alignment
|
||||||
|
size_t (*get_alignment)(ggml_backend_t backend);
|
||||||
|
|
||||||
|
// tensor data access
|
||||||
|
// these functions can be asynchronous, helper functions are provided for synchronous access that automatically call synchronize
|
||||||
void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
||||||
void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
||||||
|
|
||||||
// (optional) asynchroneous tensor copy
|
|
||||||
void (*cpy_tensor_from_async)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
|
|
||||||
void (*cpy_tensor_to_async) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
|
|
||||||
|
|
||||||
void (*synchronize) (ggml_backend_t backend);
|
void (*synchronize) (ggml_backend_t backend);
|
||||||
|
|
||||||
|
// (optional) copy tensor between different backends, allow for single-copy tranfers
|
||||||
|
void (*cpy_tensor_from)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
|
||||||
|
void (*cpy_tensor_to) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
|
||||||
|
|
||||||
// compute graph with a plan
|
// compute graph with a plan
|
||||||
ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
||||||
void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
||||||
void (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
void (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
||||||
|
|
||||||
// compute graph without a plan
|
// compute graph without a plan
|
||||||
bool (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
void (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
||||||
|
|
||||||
// check if the backend supports an operation
|
// check if the backend supports an operation
|
||||||
bool (*supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
|
bool (*supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
|
||||||
@ -102,15 +82,6 @@ extern "C" {
|
|||||||
ggml_backend_context_t context;
|
ggml_backend_context_t context;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
//
|
|
||||||
// Backend registry
|
|
||||||
//
|
|
||||||
|
|
||||||
typedef ggml_backend_t (*ggml_backend_init_fn)(const char * params, void * user_data);
|
|
||||||
|
|
||||||
void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
835
ggml-backend.c
@ -7,47 +7,41 @@
|
|||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
|
|
||||||
typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
|
|
||||||
typedef struct ggml_backend * ggml_backend_t;
|
|
||||||
typedef void * ggml_backend_graph_plan_t;
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// Backend buffer
|
// Backend buffer
|
||||||
//
|
//
|
||||||
|
|
||||||
// buffer type
|
struct ggml_backend_buffer;
|
||||||
GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);
|
typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
|
||||||
GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
|
|
||||||
GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
|
|
||||||
GGML_API bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend);
|
|
||||||
GGML_API bool ggml_backend_buft_is_host (ggml_backend_buffer_type_t buft);
|
|
||||||
|
|
||||||
// buffer
|
// backend buffer functions
|
||||||
GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer);
|
GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer);
|
||||||
|
GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
|
||||||
GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer);
|
GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer);
|
||||||
GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer);
|
GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer);
|
||||||
GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
|
|
||||||
GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
|
|
||||||
GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
|
GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
|
||||||
GGML_API void ggml_backend_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value);
|
GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
|
||||||
GGML_API bool ggml_backend_buffer_is_host (ggml_backend_buffer_t buffer);
|
GGML_API void ggml_backend_buffer_free_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
|
||||||
GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_type(ggml_backend_buffer_t buffer);
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// Backend
|
// Backend
|
||||||
//
|
//
|
||||||
|
|
||||||
|
struct ggml_backend;
|
||||||
|
typedef struct ggml_backend * ggml_backend_t;
|
||||||
|
typedef void * ggml_backend_graph_plan_t;
|
||||||
|
|
||||||
|
GGML_API ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor);
|
||||||
|
|
||||||
GGML_API const char * ggml_backend_name(ggml_backend_t backend);
|
GGML_API const char * ggml_backend_name(ggml_backend_t backend);
|
||||||
GGML_API void ggml_backend_free(ggml_backend_t backend);
|
GGML_API void ggml_backend_free(ggml_backend_t backend);
|
||||||
|
|
||||||
GGML_API ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend);
|
|
||||||
GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
|
GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
|
||||||
|
|
||||||
GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend);
|
GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend);
|
||||||
|
|
||||||
GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
GGML_API void ggml_backend_tensor_set_async( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
||||||
GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
GGML_API void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
||||||
|
|
||||||
GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
||||||
GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
||||||
@ -58,12 +52,11 @@ extern "C" {
|
|||||||
|
|
||||||
GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
||||||
GGML_API void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
GGML_API void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
||||||
GGML_API bool ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
GGML_API void ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
||||||
GGML_API bool ggml_backend_supports_op (ggml_backend_t backend, const struct ggml_tensor * op);
|
GGML_API bool ggml_backend_supports_op (ggml_backend_t backend, const struct ggml_tensor * op);
|
||||||
|
|
||||||
// tensor copy between different backends
|
// tensor copy between different backends
|
||||||
GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
|
GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
|
||||||
GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); // automatic fallback to sync copy
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// CPU backend
|
// CPU backend
|
||||||
@ -75,27 +68,8 @@ extern "C" {
|
|||||||
GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads);
|
GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads);
|
||||||
|
|
||||||
// Create a backend buffer from an existing pointer
|
// Create a backend buffer from an existing pointer
|
||||||
GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
|
GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size);
|
||||||
|
|
||||||
GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);
|
|
||||||
|
|
||||||
#ifdef GGML_USE_CPU_HBM
|
|
||||||
GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
//
|
|
||||||
// Backend registry
|
|
||||||
//
|
|
||||||
|
|
||||||
// The backend registry is a registry of all the available backends, and allows initializing backends in a generic way
|
|
||||||
|
|
||||||
GGML_API size_t ggml_backend_reg_get_count(void);
|
|
||||||
GGML_API size_t ggml_backend_reg_find_by_name(const char * name);
|
|
||||||
GGML_API ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is name[:params]
|
|
||||||
GGML_API const char * ggml_backend_reg_get_name(size_t i);
|
|
||||||
GGML_API ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific
|
|
||||||
GGML_API ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i);
|
|
||||||
GGML_API ggml_backend_buffer_t ggml_backend_reg_alloc_buffer(size_t i, size_t size);
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// Backend scheduler
|
// Backend scheduler
|
||||||
@ -157,32 +131,6 @@ extern "C" {
|
|||||||
ggml_backend_sched_t sched,
|
ggml_backend_sched_t sched,
|
||||||
struct ggml_cgraph * graph);
|
struct ggml_cgraph * graph);
|
||||||
|
|
||||||
|
|
||||||
//
|
|
||||||
// Utils
|
|
||||||
//
|
|
||||||
|
|
||||||
struct ggml_backend_graph_copy {
|
|
||||||
ggml_backend_buffer_t buffer;
|
|
||||||
struct ggml_context * ctx_allocated;
|
|
||||||
struct ggml_context * ctx_unallocated;
|
|
||||||
struct ggml_cgraph * graph;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Copy a graph to a different backend
|
|
||||||
GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph);
|
|
||||||
GGML_API void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy);
|
|
||||||
|
|
||||||
typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
|
|
||||||
|
|
||||||
// Compare the output of two backends
|
|
||||||
GGML_API void ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
|
|
||||||
|
|
||||||
// Tensor initialization
|
|
||||||
GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
|
|
||||||
GGML_API void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
|
|
||||||
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
3605
ggml-cuda.cu
10
ggml-cuda.h
@ -49,15 +49,7 @@ GGML_API int ggml_cuda_get_device_count(void);
|
|||||||
GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
|
GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
|
||||||
|
|
||||||
// backend API
|
// backend API
|
||||||
GGML_API ggml_backend_t ggml_backend_cuda_init(int device);
|
GGML_API ggml_backend_t ggml_backend_cuda_init(void); // TODO: take a list of devices to use
|
||||||
|
|
||||||
GGML_API bool ggml_backend_is_cuda(ggml_backend_t backend);
|
|
||||||
GGML_API int ggml_backend_cuda_get_device(ggml_backend_t backend);
|
|
||||||
|
|
||||||
GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
|
|
||||||
|
|
||||||
// pinned host buffer for use with CPU backend for faster copies between CPU and GPU
|
|
||||||
GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
|
@ -232,7 +232,7 @@ bool ggml_hash_contains (const struct ggml_hash_set hash_set, struct ggml
|
|||||||
// returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted
|
// returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted
|
||||||
size_t ggml_hash_find (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
|
size_t ggml_hash_find (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
|
||||||
|
|
||||||
// returns GGML_HASHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
|
// returns GGML_HAHSHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
|
||||||
size_t ggml_hash_insert ( struct ggml_hash_set hash_set, struct ggml_tensor * key);
|
size_t ggml_hash_insert ( struct ggml_hash_set hash_set, struct ggml_tensor * key);
|
||||||
|
|
||||||
// return index, asserts if table is full
|
// return index, asserts if table is full
|
||||||
|
11
ggml-metal.h
@ -87,7 +87,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx);
|
|||||||
|
|
||||||
// same as ggml_graph_compute but uses Metal
|
// same as ggml_graph_compute but uses Metal
|
||||||
// creates gf->n_threads command buffers in parallel
|
// creates gf->n_threads command buffers in parallel
|
||||||
bool ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
|
void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
|
||||||
|
|
||||||
//
|
//
|
||||||
// backend API
|
// backend API
|
||||||
@ -98,17 +98,8 @@ GGML_API ggml_backend_t ggml_backend_metal_init(void);
|
|||||||
|
|
||||||
GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
|
GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
|
||||||
|
|
||||||
GGML_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size);
|
|
||||||
|
|
||||||
GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
|
GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
|
||||||
|
|
||||||
GGML_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
|
|
||||||
|
|
||||||
// helper to check if the device supports a specific family
|
|
||||||
// ideally, the user code should be doing these checks
|
|
||||||
// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
|
||||||
GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
1545
ggml-metal.m
2935
ggml-metal.metal
@ -1,18 +1,20 @@
|
|||||||
#include "ggml.h"
|
|
||||||
#include "ggml-opencl.h"
|
#include "ggml-opencl.h"
|
||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
#include <cstdio>
|
|
||||||
#include <cstdlib>
|
|
||||||
#include <cstring>
|
|
||||||
#include <limits>
|
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
#define CL_TARGET_OPENCL_VERSION 110
|
#define CL_TARGET_OPENCL_VERSION 110
|
||||||
#include <clblast.h>
|
#include <clblast.h>
|
||||||
|
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
|
||||||
#if defined(_MSC_VER)
|
#if defined(_MSC_VER)
|
||||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||||
#endif
|
#endif
|
||||||
|
@ -6,19 +6,19 @@
|
|||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
GGML_API void ggml_cl_init(void);
|
void ggml_cl_init(void);
|
||||||
|
|
||||||
GGML_API void ggml_cl_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
|
void ggml_cl_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
|
||||||
GGML_API bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
|
bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
|
||||||
GGML_API size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
|
size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
|
||||||
GGML_API void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
|
void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
|
||||||
|
|
||||||
GGML_API void * ggml_cl_host_malloc(size_t size);
|
void * ggml_cl_host_malloc(size_t size);
|
||||||
GGML_API void ggml_cl_host_free(void * ptr);
|
void ggml_cl_host_free(void * ptr);
|
||||||
|
|
||||||
GGML_API void ggml_cl_free_data(const struct ggml_tensor* tensor);
|
void ggml_cl_free_data(const struct ggml_tensor* tensor);
|
||||||
|
|
||||||
GGML_API void ggml_cl_transform_tensor(void * data, struct ggml_tensor * tensor);
|
void ggml_cl_transform_tensor(void * data, struct ggml_tensor * tensor);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
|
481
ggml-quants.c
@ -19,7 +19,7 @@
|
|||||||
#ifdef __wasm_simd128__
|
#ifdef __wasm_simd128__
|
||||||
#include <wasm_simd128.h>
|
#include <wasm_simd128.h>
|
||||||
#else
|
#else
|
||||||
#if defined(__POWER9_VECTOR__) || defined(__powerpc64__)
|
#ifdef __POWER9_VECTOR__
|
||||||
#include <altivec.h>
|
#include <altivec.h>
|
||||||
#undef bool
|
#undef bool
|
||||||
#define bool _Bool
|
#define bool _Bool
|
||||||
@ -407,22 +407,6 @@ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
|
|||||||
#define ggml_vld1q_s8_x4 vld1q_s8_x4
|
#define ggml_vld1q_s8_x4 vld1q_s8_x4
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if !defined(__ARM_FEATURE_DOTPROD)
|
|
||||||
|
|
||||||
inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
|
|
||||||
const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
|
|
||||||
const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
|
|
||||||
|
|
||||||
return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
|
|
||||||
}
|
|
||||||
|
|
||||||
#else
|
|
||||||
|
|
||||||
#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(__ARM_NEON) || defined(__wasm_simd128__)
|
#if defined(__ARM_NEON) || defined(__wasm_simd128__)
|
||||||
@ -2484,12 +2468,32 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx,
|
|||||||
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
||||||
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
||||||
|
|
||||||
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
// dot product into int32x4_t
|
// dot product into int32x4_t
|
||||||
const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
|
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
|
||||||
const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
|
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
|
||||||
|
|
||||||
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
|
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
|
||||||
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
|
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
|
||||||
|
#else
|
||||||
|
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
|
||||||
|
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
|
||||||
|
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h));
|
||||||
|
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
|
||||||
|
|
||||||
|
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l));
|
||||||
|
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
|
||||||
|
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h));
|
||||||
|
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
|
||||||
|
|
||||||
|
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
||||||
|
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
||||||
|
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
|
||||||
|
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
|
||||||
|
|
||||||
|
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
|
||||||
|
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
||||||
@ -2772,12 +2776,32 @@ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restri
|
|||||||
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
||||||
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
||||||
|
|
||||||
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
// dot product into int32x4_t
|
// dot product into int32x4_t
|
||||||
const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
|
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
|
||||||
const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
|
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
|
||||||
|
|
||||||
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d);
|
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d);
|
||||||
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d);
|
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d);
|
||||||
|
#else
|
||||||
|
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l));
|
||||||
|
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l));
|
||||||
|
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h));
|
||||||
|
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h));
|
||||||
|
|
||||||
|
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l));
|
||||||
|
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l));
|
||||||
|
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h));
|
||||||
|
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h));
|
||||||
|
|
||||||
|
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
||||||
|
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
||||||
|
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
|
||||||
|
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
|
||||||
|
|
||||||
|
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
|
||||||
|
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
|
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
|
||||||
@ -2939,12 +2963,32 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
|
|||||||
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
||||||
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
||||||
|
|
||||||
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
|
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
|
||||||
ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
|
vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
|
||||||
ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
|
vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
|
||||||
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
|
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
|
||||||
ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
|
vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
|
||||||
ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
|
vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
|
||||||
|
#else
|
||||||
|
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
|
||||||
|
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
|
||||||
|
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
|
||||||
|
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
|
||||||
|
|
||||||
|
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
|
||||||
|
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
|
||||||
|
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
|
||||||
|
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
|
||||||
|
|
||||||
|
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
||||||
|
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
||||||
|
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
|
||||||
|
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
|
||||||
|
|
||||||
|
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
|
||||||
|
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
||||||
@ -3070,7 +3114,7 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
|
|||||||
|
|
||||||
size_t vl = __riscv_vsetvl_e8m1(qk/2);
|
size_t vl = __riscv_vsetvl_e8m1(qk/2);
|
||||||
|
|
||||||
// These temporary registers are for masking and shift operations
|
// These tempory registers are for masking and shift operations
|
||||||
vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
|
vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
|
||||||
vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
|
vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
|
||||||
|
|
||||||
@ -3231,12 +3275,32 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri
|
|||||||
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
||||||
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
||||||
|
|
||||||
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
|
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
|
||||||
ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
|
vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
|
||||||
ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
|
vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
|
||||||
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
|
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
|
||||||
ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
|
vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
|
||||||
ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
|
vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
|
||||||
|
#else
|
||||||
|
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
|
||||||
|
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
|
||||||
|
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
|
||||||
|
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
|
||||||
|
|
||||||
|
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
|
||||||
|
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
|
||||||
|
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
|
||||||
|
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
|
||||||
|
|
||||||
|
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
||||||
|
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
||||||
|
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
|
||||||
|
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
|
||||||
|
|
||||||
|
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
|
||||||
|
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
|
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
|
||||||
@ -3486,13 +3550,34 @@ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restri
|
|||||||
const int8x16_t y1_0 = vld1q_s8(y1->qs);
|
const int8x16_t y1_0 = vld1q_s8(y1->qs);
|
||||||
const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
|
const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
|
||||||
|
|
||||||
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
|
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
|
||||||
ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
|
vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
|
||||||
ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
|
vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
|
||||||
|
|
||||||
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
|
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
|
||||||
ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
|
vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
|
||||||
ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
|
vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
|
||||||
|
|
||||||
|
#else
|
||||||
|
const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
|
||||||
|
const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
|
||||||
|
const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
|
||||||
|
const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
|
||||||
|
|
||||||
|
const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
|
||||||
|
const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
|
||||||
|
const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
|
||||||
|
const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
|
||||||
|
|
||||||
|
const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
|
||||||
|
const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
|
||||||
|
const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
|
||||||
|
const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
|
||||||
|
|
||||||
|
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
|
||||||
|
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
||||||
@ -3565,10 +3650,12 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
const int nb = n / QK_K;
|
const int nb = n / QK_K;
|
||||||
|
|
||||||
#ifdef __ARM_NEON
|
#ifdef __ARM_NEON
|
||||||
|
|
||||||
const uint8x16_t m3 = vdupq_n_u8(0x3);
|
const uint8x16_t m3 = vdupq_n_u8(0x3);
|
||||||
const uint8x16_t m4 = vdupq_n_u8(0xF);
|
const uint8x16_t m4 = vdupq_n_u8(0xF);
|
||||||
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
const int32x4_t vzero = vdupq_n_s32(0);
|
const int32x4_t vzero = vdupq_n_s32(0);
|
||||||
|
#endif
|
||||||
|
|
||||||
ggml_int8x16x2_t q2bytes;
|
ggml_int8x16x2_t q2bytes;
|
||||||
uint8_t aux[16];
|
uint8_t aux[16];
|
||||||
@ -3576,6 +3663,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
float sum = 0;
|
float sum = 0;
|
||||||
|
|
||||||
for (int i = 0; i < nb; ++i) {
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
|
||||||
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
|
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
|
||||||
const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
|
const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
|
||||||
|
|
||||||
@ -3589,7 +3677,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
|
|
||||||
const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
|
const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
|
||||||
const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
|
const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
|
||||||
const ggml_int16x8x2_t mins16 = {{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}};
|
const ggml_int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
|
||||||
const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
|
const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
|
||||||
vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
|
vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
|
||||||
const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
|
const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
|
||||||
@ -3601,9 +3689,20 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
|
|
||||||
// We use this macro instead of a function call because for some reason
|
// We use this macro instead of a function call because for some reason
|
||||||
// the code runs 2-3% slower, even if the function is declared inline
|
// the code runs 2-3% slower, even if the function is declared inline
|
||||||
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
#define MULTIPLY_ACCUM_WITH_SCALE(index)\
|
#define MULTIPLY_ACCUM_WITH_SCALE(index)\
|
||||||
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
|
isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
|
||||||
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
|
isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
|
||||||
|
#else
|
||||||
|
#define MULTIPLY_ACCUM_WITH_SCALE(index)\
|
||||||
|
{\
|
||||||
|
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\
|
||||||
|
vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\
|
||||||
|
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\
|
||||||
|
vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\
|
||||||
|
isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
|
#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
|
||||||
q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\
|
q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\
|
||||||
@ -3611,23 +3710,26 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
|
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
|
||||||
MULTIPLY_ACCUM_WITH_SCALE((index));
|
MULTIPLY_ACCUM_WITH_SCALE((index));
|
||||||
|
|
||||||
|
|
||||||
for (int j = 0; j < QK_K/128; ++j) {
|
for (int j = 0; j < QK_K/128; ++j) {
|
||||||
|
|
||||||
const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32;
|
const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32;
|
||||||
|
|
||||||
ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
|
ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
|
||||||
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
|
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
|
||||||
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
|
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
|
||||||
|
|
||||||
MULTIPLY_ACCUM_WITH_SCALE(0);
|
MULTIPLY_ACCUM_WITH_SCALE(0);
|
||||||
|
|
||||||
SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2);
|
SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2);
|
||||||
|
|
||||||
SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4);
|
SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4);
|
||||||
|
|
||||||
SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6);
|
SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6);
|
||||||
|
|
||||||
is += 8;
|
is += 8;
|
||||||
}
|
}
|
||||||
|
|
||||||
sum += d * isum;
|
sum += d * isum;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
*s = sum;
|
*s = sum;
|
||||||
@ -3941,9 +4043,11 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
const int nb = n / QK_K;
|
const int nb = n / QK_K;
|
||||||
|
|
||||||
#ifdef __ARM_NEON
|
#ifdef __ARM_NEON
|
||||||
const uint8x16_t m3 = vdupq_n_u8(0x3);
|
|
||||||
|
|
||||||
|
const uint8x16_t m3 = vdupq_n_u8(0x3);
|
||||||
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
const int32x4_t vzero = vdupq_n_s32(0);
|
const int32x4_t vzero = vdupq_n_s32(0);
|
||||||
|
#endif
|
||||||
|
|
||||||
ggml_int8x16x4_t q2bytes;
|
ggml_int8x16x4_t q2bytes;
|
||||||
|
|
||||||
@ -3977,12 +4081,28 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3));
|
q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3));
|
||||||
q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3));
|
q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3));
|
||||||
|
|
||||||
isum1 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
isum2 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
|
isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
|
||||||
isum1 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
|
isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
|
||||||
isum2 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
|
isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
|
||||||
|
isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
|
||||||
|
#else
|
||||||
|
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
||||||
|
vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
||||||
|
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
||||||
|
vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
||||||
|
isum1 += vaddvq_s16(p1) * scales[0];
|
||||||
|
isum2 += vaddvq_s16(p2) * scales[1];
|
||||||
|
|
||||||
|
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
|
||||||
|
vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes.val[2])));
|
||||||
|
const int16x8_t p4 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
|
||||||
|
vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes.val[3])));
|
||||||
|
isum1 += vaddvq_s16(p3) * scales[2];
|
||||||
|
isum2 += vaddvq_s16(p4) * scales[3];
|
||||||
|
#endif
|
||||||
sum += d * (isum1 + isum2);
|
sum += d * (isum1 + isum2);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
*s = sum;
|
*s = sum;
|
||||||
@ -4208,7 +4328,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
uint32_t utmp[4];
|
uint32_t utmp[4];
|
||||||
|
|
||||||
const uint8x16_t m3b = vdupq_n_u8(0x3);
|
const uint8x16_t m3b = vdupq_n_u8(0x3);
|
||||||
|
#ifdef __ARM_FEATURE_DOTPROD
|
||||||
const int32x4_t vzero = vdupq_n_s32(0);
|
const int32x4_t vzero = vdupq_n_s32(0);
|
||||||
|
#endif
|
||||||
|
|
||||||
const uint8x16_t m0 = vdupq_n_u8(1);
|
const uint8x16_t m0 = vdupq_n_u8(1);
|
||||||
const uint8x16_t m1 = vshlq_n_u8(m0, 1);
|
const uint8x16_t m1 = vshlq_n_u8(m0, 1);
|
||||||
@ -4260,11 +4382,22 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
|
q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
|
||||||
q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
|
q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
|
||||||
|
|
||||||
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
|
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
|
||||||
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
|
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
|
||||||
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
|
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
|
||||||
|
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
|
||||||
|
#else
|
||||||
|
int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])),
|
||||||
|
vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0])));
|
||||||
|
int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])),
|
||||||
|
vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_1.val[1])));
|
||||||
|
int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])),
|
||||||
|
vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2])));
|
||||||
|
int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])),
|
||||||
|
vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3])));
|
||||||
|
isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
|
||||||
|
#endif
|
||||||
scale += 4;
|
scale += 4;
|
||||||
|
|
||||||
q3h.val[0] = vbicq_u8(m2, qhbits.val[0]);
|
q3h.val[0] = vbicq_u8(m2, qhbits.val[0]);
|
||||||
@ -4277,11 +4410,22 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
|
q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
|
||||||
q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
|
q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
|
||||||
|
|
||||||
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
|
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
|
||||||
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
|
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
|
||||||
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
|
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
|
||||||
|
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
|
||||||
|
#else
|
||||||
|
p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])),
|
||||||
|
vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0])));
|
||||||
|
p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])),
|
||||||
|
vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_2.val[1])));
|
||||||
|
p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])),
|
||||||
|
vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2])));
|
||||||
|
p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])),
|
||||||
|
vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3])));
|
||||||
|
isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
|
||||||
|
#endif
|
||||||
scale += 4;
|
scale += 4;
|
||||||
|
|
||||||
if (j == 0) {
|
if (j == 0) {
|
||||||
@ -4613,7 +4757,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
|
|
||||||
vl = 16;
|
vl = 16;
|
||||||
|
|
||||||
// retrieve lane to multiply with scale
|
// retreive lane to multiply with scale
|
||||||
vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
|
vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
|
||||||
vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
|
vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
|
||||||
vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
|
vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
|
||||||
@ -4720,7 +4864,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
const int nb = n / QK_K;
|
const int nb = n / QK_K;
|
||||||
|
|
||||||
#ifdef __ARM_NEON
|
#ifdef __ARM_NEON
|
||||||
|
|
||||||
|
#ifdef __ARM_FEATURE_DOTPROD
|
||||||
const int32x4_t vzero = vdupq_n_s32(0);
|
const int32x4_t vzero = vdupq_n_s32(0);
|
||||||
|
#endif
|
||||||
|
|
||||||
const uint8x16_t m3b = vdupq_n_u8(0x3);
|
const uint8x16_t m3b = vdupq_n_u8(0x3);
|
||||||
const uint8x16_t mh = vdupq_n_u8(4);
|
const uint8x16_t mh = vdupq_n_u8(4);
|
||||||
@ -4761,10 +4908,22 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
|
q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
|
||||||
q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3]));
|
q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3]));
|
||||||
|
|
||||||
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
|
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
|
||||||
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
|
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
|
||||||
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
|
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
|
||||||
|
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
|
||||||
|
#else
|
||||||
|
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
||||||
|
vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
||||||
|
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
||||||
|
vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
||||||
|
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
|
||||||
|
vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2])));
|
||||||
|
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
|
||||||
|
vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3])));
|
||||||
|
isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[2] + vaddvq_s16(p2) * scales[1] + vaddvq_s16(p3) * scales[3];
|
||||||
|
#endif
|
||||||
|
|
||||||
sum += d * isum;
|
sum += d * isum;
|
||||||
|
|
||||||
@ -5069,8 +5228,11 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
uint32_t utmp[4];
|
uint32_t utmp[4];
|
||||||
|
|
||||||
#ifdef __ARM_NEON
|
#ifdef __ARM_NEON
|
||||||
|
|
||||||
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
||||||
|
#ifdef __ARM_FEATURE_DOTPROD
|
||||||
const int32x4_t mzero = vdupq_n_s32(0);
|
const int32x4_t mzero = vdupq_n_s32(0);
|
||||||
|
#endif
|
||||||
|
|
||||||
ggml_int8x16x2_t q4bytes;
|
ggml_int8x16x2_t q4bytes;
|
||||||
ggml_int8x16x2_t q8bytes;
|
ggml_int8x16x2_t q8bytes;
|
||||||
@ -5107,22 +5269,44 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
int32_t sumi2 = 0;
|
int32_t sumi2 = 0;
|
||||||
|
|
||||||
for (int j = 0; j < QK_K/64; ++j) {
|
for (int j = 0; j < QK_K/64; ++j) {
|
||||||
|
|
||||||
const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
|
const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
|
||||||
|
|
||||||
|
#ifdef __ARM_FEATURE_DOTPROD
|
||||||
q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
|
q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
|
||||||
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
|
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
|
||||||
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
|
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
|
||||||
|
|
||||||
const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
|
const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
|
||||||
sumi1 += vaddvq_s32(p1) * scales[2*j+0];
|
sumi1 += vaddvq_s32(p1) * scales[2*j+0];
|
||||||
|
|
||||||
q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
|
q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
|
||||||
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
|
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
|
||||||
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
|
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
|
||||||
|
|
||||||
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
|
const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
|
||||||
|
|
||||||
sumi2 += vaddvq_s32(p2) * scales[2*j+1];
|
sumi2 += vaddvq_s32(p2) * scales[2*j+1];
|
||||||
|
#else
|
||||||
|
q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
|
||||||
|
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
|
||||||
|
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
|
||||||
|
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
||||||
|
vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
||||||
|
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
||||||
|
vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
||||||
|
sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0];
|
||||||
|
|
||||||
|
q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
|
||||||
|
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
|
||||||
|
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
|
||||||
|
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
||||||
|
vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
||||||
|
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
||||||
|
vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
||||||
|
sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) * scales[2*j+1];
|
||||||
|
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
sumf += d * (sumi1 + sumi2);
|
sumf += d * (sumi1 + sumi2);
|
||||||
@ -5419,9 +5603,12 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
const int nb = n / QK_K;
|
const int nb = n / QK_K;
|
||||||
|
|
||||||
#ifdef __ARM_NEON
|
#ifdef __ARM_NEON
|
||||||
|
|
||||||
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
||||||
|
|
||||||
|
#ifdef __ARM_FEATURE_DOTPROD
|
||||||
const int32x4_t mzero = vdupq_n_s32(0);
|
const int32x4_t mzero = vdupq_n_s32(0);
|
||||||
|
#endif
|
||||||
|
|
||||||
float sumf = 0;
|
float sumf = 0;
|
||||||
|
|
||||||
@ -5449,20 +5636,41 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
|
|
||||||
const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4);
|
const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4);
|
||||||
|
|
||||||
|
#ifdef __ARM_FEATURE_DOTPROD
|
||||||
q8bytes = ggml_vld1q_s8_x4(q8);
|
q8bytes = ggml_vld1q_s8_x4(q8);
|
||||||
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
|
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
|
||||||
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
|
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
|
||||||
|
|
||||||
const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
|
const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
|
||||||
const int32_t sumi1 = vaddvq_s32(p1) * scales[0];
|
const int32_t sumi1 = vaddvq_s32(p1) * scales[0];
|
||||||
|
|
||||||
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
|
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
|
||||||
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
|
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
|
||||||
|
|
||||||
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
|
const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
|
||||||
const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
|
const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
|
||||||
|
|
||||||
|
#else
|
||||||
|
q8bytes = ggml_vld1q_s8_x4(q8);
|
||||||
|
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
|
||||||
|
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
|
||||||
|
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
||||||
|
vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
||||||
|
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
||||||
|
vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
||||||
|
int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * scales[0];
|
||||||
|
|
||||||
|
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
|
||||||
|
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
|
||||||
|
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[2])),
|
||||||
|
vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2])));
|
||||||
|
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])),
|
||||||
|
vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3])));
|
||||||
|
int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1];
|
||||||
|
|
||||||
|
#endif
|
||||||
sumf += d * (sumi1 + sumi2);
|
sumf += d * (sumi1 + sumi2);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
*s = sumf - sum_mins;
|
*s = sumf - sum_mins;
|
||||||
@ -5667,11 +5875,15 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
|
|
||||||
uint32_t utmp[4];
|
uint32_t utmp[4];
|
||||||
|
|
||||||
|
|
||||||
#ifdef __ARM_NEON
|
#ifdef __ARM_NEON
|
||||||
|
|
||||||
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
||||||
const uint8x16_t mone = vdupq_n_u8(1);
|
const uint8x16_t mone = vdupq_n_u8(1);
|
||||||
const uint8x16_t mtwo = vdupq_n_u8(2);
|
const uint8x16_t mtwo = vdupq_n_u8(2);
|
||||||
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
const int32x4_t mzero = vdupq_n_s32(0);
|
const int32x4_t mzero = vdupq_n_s32(0);
|
||||||
|
#endif
|
||||||
|
|
||||||
ggml_int8x16x4_t q5bytes;
|
ggml_int8x16x4_t q5bytes;
|
||||||
|
|
||||||
@ -5726,11 +5938,28 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
|
q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
|
||||||
q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
|
q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
|
||||||
|
|
||||||
sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
|
|
||||||
|
sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
|
||||||
|
sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
|
||||||
|
#else
|
||||||
|
|
||||||
|
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
||||||
|
vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
||||||
|
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
||||||
|
vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
||||||
|
sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++;
|
||||||
|
|
||||||
|
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
|
||||||
|
vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
|
||||||
|
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
|
||||||
|
vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
|
||||||
|
sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
sumf += d * sumi - dmin * sumi_mins;
|
sumf += d * sumi - dmin * sumi_mins;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
*s = sumf;
|
*s = sumf;
|
||||||
@ -6082,9 +6311,12 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
const int nb = n / QK_K;
|
const int nb = n / QK_K;
|
||||||
|
|
||||||
#ifdef __ARM_NEON
|
#ifdef __ARM_NEON
|
||||||
|
|
||||||
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
||||||
const uint8x16_t mh = vdupq_n_u8(16);
|
const uint8x16_t mh = vdupq_n_u8(16);
|
||||||
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
const int32x4_t mzero = vdupq_n_s32(0);
|
const int32x4_t mzero = vdupq_n_s32(0);
|
||||||
|
#endif
|
||||||
|
|
||||||
ggml_int8x16x4_t q5bytes;
|
ggml_int8x16x4_t q5bytes;
|
||||||
ggml_uint8x16x4_t q5h;
|
ggml_uint8x16x4_t q5h;
|
||||||
@ -6116,12 +6348,32 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2]));
|
q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2]));
|
||||||
q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3]));
|
q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3]));
|
||||||
|
|
||||||
int32_t sumi1 = sc[0] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
int32_t sumi2 = sc[1] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
|
|
||||||
int32_t sumi3 = sc[2] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
|
int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
|
||||||
int32_t sumi4 = sc[3] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
|
int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
|
||||||
|
int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
|
||||||
|
int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
|
||||||
|
|
||||||
sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
|
sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
||||||
|
vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
||||||
|
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
||||||
|
vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
||||||
|
int32_t sumi = sc[0] * vaddvq_s16(p0) + sc[1] * vaddvq_s16(p1);
|
||||||
|
|
||||||
|
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
|
||||||
|
vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
|
||||||
|
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
|
||||||
|
vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
|
||||||
|
sumi += sc[2] * vaddvq_s16(p2) + sc[3] * vaddvq_s16(p3);
|
||||||
|
|
||||||
|
sumf += d*sumi;
|
||||||
|
#endif
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
*s = sumf;
|
*s = sumf;
|
||||||
@ -6348,10 +6600,13 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
const int nb = n / QK_K;
|
const int nb = n / QK_K;
|
||||||
|
|
||||||
#ifdef __ARM_NEON
|
#ifdef __ARM_NEON
|
||||||
|
|
||||||
float sum = 0;
|
float sum = 0;
|
||||||
|
|
||||||
const uint8x16_t m4b = vdupq_n_u8(0xF);
|
const uint8x16_t m4b = vdupq_n_u8(0xF);
|
||||||
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
const int32x4_t vzero = vdupq_n_s32(0);
|
const int32x4_t vzero = vdupq_n_s32(0);
|
||||||
|
#endif
|
||||||
//const int8x16_t m32s = vdupq_n_s8(32);
|
//const int8x16_t m32s = vdupq_n_s8(32);
|
||||||
|
|
||||||
const uint8x16_t mone = vdupq_n_u8(3);
|
const uint8x16_t mone = vdupq_n_u8(3);
|
||||||
@ -6371,7 +6626,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
|
|
||||||
const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
|
const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
|
||||||
const int8x16_t scales = vld1q_s8(scale);
|
const int8x16_t scales = vld1q_s8(scale);
|
||||||
const ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}};
|
const ggml_int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
|
||||||
|
|
||||||
const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
|
const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
|
||||||
vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
|
vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
|
||||||
@ -6403,13 +6658,31 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
|
q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
|
||||||
q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));
|
q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));
|
||||||
|
|
||||||
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
|
|
||||||
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
|
|
||||||
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
|
|
||||||
|
|
||||||
|
isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
|
||||||
|
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
|
||||||
|
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
|
||||||
|
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
|
||||||
scale += 4;
|
scale += 4;
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
||||||
|
int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
||||||
|
isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
|
||||||
|
scale += 2;
|
||||||
|
|
||||||
|
int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
|
||||||
|
int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
|
||||||
|
isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
|
||||||
|
scale += 2;
|
||||||
|
#endif
|
||||||
|
|
||||||
q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
|
q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
|
||||||
|
|
||||||
shifted = vshrq_n_u8(qhbits.val[0], 4);
|
shifted = vshrq_n_u8(qhbits.val[0], 4);
|
||||||
@ -6430,11 +6703,34 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
|
q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
|
||||||
q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
|
q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
|
||||||
|
|
||||||
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
|
|
||||||
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
|
isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
|
||||||
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
|
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
|
||||||
|
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
|
||||||
|
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
|
||||||
scale += 4;
|
scale += 4;
|
||||||
|
|
||||||
|
//for (int l = 0; l < 4; ++l) {
|
||||||
|
// const int32x4_t p = vdotq_s32(vzero, q6bytes.val[l], q8bytes.val[l]);
|
||||||
|
// isum += vaddvq_s32(p) * *scale++;
|
||||||
|
//}
|
||||||
|
#else
|
||||||
|
p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
||||||
|
p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
||||||
|
isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
|
||||||
|
scale += 2;
|
||||||
|
|
||||||
|
p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
|
||||||
|
p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
|
||||||
|
isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
|
||||||
|
scale += 2;
|
||||||
|
#endif
|
||||||
|
|
||||||
}
|
}
|
||||||
//sum += isum * d_all * y[i].d;
|
//sum += isum * d_all * y[i].d;
|
||||||
sum += d_all * y[i].d * (isum - 32 * isum_mins);
|
sum += d_all * y[i].d * (isum - 32 * isum_mins);
|
||||||
@ -6780,11 +7076,14 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
const int nb = n / QK_K;
|
const int nb = n / QK_K;
|
||||||
|
|
||||||
#ifdef __ARM_NEON
|
#ifdef __ARM_NEON
|
||||||
|
|
||||||
float sum = 0;
|
float sum = 0;
|
||||||
|
|
||||||
const uint8x16_t m4b = vdupq_n_u8(0xF);
|
const uint8x16_t m4b = vdupq_n_u8(0xF);
|
||||||
const int8x16_t m32s = vdupq_n_s8(32);
|
const int8x16_t m32s = vdupq_n_s8(32);
|
||||||
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
const int32x4_t vzero = vdupq_n_s32(0);
|
const int32x4_t vzero = vdupq_n_s32(0);
|
||||||
|
#endif
|
||||||
|
|
||||||
const uint8x16_t mone = vdupq_n_u8(3);
|
const uint8x16_t mone = vdupq_n_u8(3);
|
||||||
|
|
||||||
@ -6820,10 +7119,26 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
|||||||
q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s);
|
q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s);
|
||||||
q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s);
|
q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s);
|
||||||
|
|
||||||
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
|
|
||||||
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
|
isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
|
||||||
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
|
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
|
||||||
|
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
|
||||||
|
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
|
||||||
|
#else
|
||||||
|
|
||||||
|
int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
||||||
|
int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
||||||
|
isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
|
||||||
|
|
||||||
|
int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
|
||||||
|
int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
|
||||||
|
vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
|
||||||
|
isum += vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
|
||||||
|
#endif
|
||||||
|
|
||||||
sum += isum * d_all * y[i].d;
|
sum += isum * d_all * y[i].d;
|
||||||
|
|
||||||
|
135
ggml.h
@ -215,9 +215,9 @@
|
|||||||
#define GGML_QNT_VERSION_FACTOR 1000 // do not change this
|
#define GGML_QNT_VERSION_FACTOR 1000 // do not change this
|
||||||
|
|
||||||
#define GGML_MAX_DIMS 4
|
#define GGML_MAX_DIMS 4
|
||||||
#define GGML_MAX_PARAMS 2048
|
#define GGML_MAX_PARAMS 1024
|
||||||
#define GGML_MAX_CONTEXTS 64
|
#define GGML_MAX_CONTEXTS 64
|
||||||
#define GGML_MAX_SRC 10
|
#define GGML_MAX_SRC 6
|
||||||
#define GGML_MAX_NAME 64
|
#define GGML_MAX_NAME 64
|
||||||
#define GGML_MAX_OP_PARAMS 64
|
#define GGML_MAX_OP_PARAMS 64
|
||||||
#define GGML_DEFAULT_N_THREADS 4
|
#define GGML_DEFAULT_N_THREADS 4
|
||||||
@ -244,10 +244,11 @@
|
|||||||
#define GGML_ASSERT(x) \
|
#define GGML_ASSERT(x) \
|
||||||
do { \
|
do { \
|
||||||
if (!(x)) { \
|
if (!(x)) { \
|
||||||
fflush(stdout); \
|
|
||||||
fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
|
fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
|
||||||
|
fflush(stderr); \
|
||||||
|
fflush(stdout); \
|
||||||
ggml_print_backtrace(); \
|
ggml_print_backtrace(); \
|
||||||
abort(); \
|
exit(1); \
|
||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
@ -255,8 +256,6 @@
|
|||||||
#define GGML_UNREACHABLE() GGML_ASSERT(!"statement should not be reached")
|
#define GGML_UNREACHABLE() GGML_ASSERT(!"statement should not be reached")
|
||||||
#elif defined(__GNUC__)
|
#elif defined(__GNUC__)
|
||||||
#define GGML_UNREACHABLE() __builtin_unreachable()
|
#define GGML_UNREACHABLE() __builtin_unreachable()
|
||||||
#elif defined(_MSC_VER)
|
|
||||||
#define GGML_UNREACHABLE() __assume(0)
|
|
||||||
#else
|
#else
|
||||||
#define GGML_UNREACHABLE() ((void) 0)
|
#define GGML_UNREACHABLE() ((void) 0)
|
||||||
#endif
|
#endif
|
||||||
@ -285,27 +284,13 @@
|
|||||||
const type prefix##3 = (pointer)->array[3]; \
|
const type prefix##3 = (pointer)->array[3]; \
|
||||||
GGML_UNUSED(prefix##3);
|
GGML_UNUSED(prefix##3);
|
||||||
|
|
||||||
#define GGML_TENSOR_UNARY_OP_LOCALS \
|
|
||||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
|
|
||||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
|
|
||||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
|
|
||||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
||||||
|
|
||||||
#define GGML_TENSOR_BINARY_OP_LOCALS \
|
|
||||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
|
|
||||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
|
|
||||||
GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
|
|
||||||
GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
|
|
||||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
|
|
||||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(__ARM_NEON) && defined(__CUDACC__)
|
#if defined(__ARM_NEON) && defined(__CUDACC__)
|
||||||
typedef half ggml_fp16_t;
|
typedef half ggml_fp16_t;
|
||||||
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
|
#elif defined(__ARM_NEON)
|
||||||
typedef __fp16 ggml_fp16_t;
|
typedef __fp16 ggml_fp16_t;
|
||||||
#else
|
#else
|
||||||
typedef uint16_t ggml_fp16_t;
|
typedef uint16_t ggml_fp16_t;
|
||||||
@ -345,12 +330,6 @@ extern "C" {
|
|||||||
GGML_TYPE_COUNT,
|
GGML_TYPE_COUNT,
|
||||||
};
|
};
|
||||||
|
|
||||||
// precision
|
|
||||||
enum ggml_prec {
|
|
||||||
GGML_PREC_DEFAULT,
|
|
||||||
GGML_PREC_F32,
|
|
||||||
};
|
|
||||||
|
|
||||||
enum ggml_backend_type {
|
enum ggml_backend_type {
|
||||||
GGML_BACKEND_CPU = 0,
|
GGML_BACKEND_CPU = 0,
|
||||||
GGML_BACKEND_GPU = 10,
|
GGML_BACKEND_GPU = 10,
|
||||||
@ -403,7 +382,6 @@ extern "C" {
|
|||||||
GGML_OP_GROUP_NORM,
|
GGML_OP_GROUP_NORM,
|
||||||
|
|
||||||
GGML_OP_MUL_MAT,
|
GGML_OP_MUL_MAT,
|
||||||
GGML_OP_MUL_MAT_ID,
|
|
||||||
GGML_OP_OUT_PROD,
|
GGML_OP_OUT_PROD,
|
||||||
|
|
||||||
GGML_OP_SCALE,
|
GGML_OP_SCALE,
|
||||||
@ -430,10 +408,8 @@ extern "C" {
|
|||||||
GGML_OP_CONV_TRANSPOSE_2D,
|
GGML_OP_CONV_TRANSPOSE_2D,
|
||||||
GGML_OP_POOL_1D,
|
GGML_OP_POOL_1D,
|
||||||
GGML_OP_POOL_2D,
|
GGML_OP_POOL_2D,
|
||||||
|
|
||||||
GGML_OP_UPSCALE, // nearest interpolate
|
GGML_OP_UPSCALE, // nearest interpolate
|
||||||
GGML_OP_PAD,
|
|
||||||
GGML_OP_ARGSORT,
|
|
||||||
GGML_OP_LEAKY_RELU,
|
|
||||||
|
|
||||||
GGML_OP_FLASH_ATTN,
|
GGML_OP_FLASH_ATTN,
|
||||||
GGML_OP_FLASH_FF,
|
GGML_OP_FLASH_FF,
|
||||||
@ -473,8 +449,7 @@ extern "C" {
|
|||||||
GGML_UNARY_OP_GELU,
|
GGML_UNARY_OP_GELU,
|
||||||
GGML_UNARY_OP_GELU_QUICK,
|
GGML_UNARY_OP_GELU_QUICK,
|
||||||
GGML_UNARY_OP_SILU,
|
GGML_UNARY_OP_SILU,
|
||||||
|
GGML_UNARY_OP_LEAKY
|
||||||
GGML_UNARY_OP_COUNT,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
enum ggml_object_type {
|
enum ggml_object_type {
|
||||||
@ -486,8 +461,7 @@ extern "C" {
|
|||||||
enum ggml_log_level {
|
enum ggml_log_level {
|
||||||
GGML_LOG_LEVEL_ERROR = 2,
|
GGML_LOG_LEVEL_ERROR = 2,
|
||||||
GGML_LOG_LEVEL_WARN = 3,
|
GGML_LOG_LEVEL_WARN = 3,
|
||||||
GGML_LOG_LEVEL_INFO = 4,
|
GGML_LOG_LEVEL_INFO = 4
|
||||||
GGML_LOG_LEVEL_DEBUG = 5
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// ggml object
|
// ggml object
|
||||||
@ -511,6 +485,7 @@ extern "C" {
|
|||||||
|
|
||||||
struct ggml_backend_buffer * buffer;
|
struct ggml_backend_buffer * buffer;
|
||||||
|
|
||||||
|
int n_dims;
|
||||||
int64_t ne[GGML_MAX_DIMS]; // number of elements
|
int64_t ne[GGML_MAX_DIMS]; // number of elements
|
||||||
size_t nb[GGML_MAX_DIMS]; // stride in bytes:
|
size_t nb[GGML_MAX_DIMS]; // stride in bytes:
|
||||||
// nb[0] = ggml_type_size(type)
|
// nb[0] = ggml_type_size(type)
|
||||||
@ -542,7 +517,7 @@ extern "C" {
|
|||||||
|
|
||||||
void * extra; // extra things e.g. for ggml-cuda.cu
|
void * extra; // extra things e.g. for ggml-cuda.cu
|
||||||
|
|
||||||
char padding[8];
|
char padding[12];
|
||||||
};
|
};
|
||||||
|
|
||||||
static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
|
static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
|
||||||
@ -647,22 +622,16 @@ extern "C" {
|
|||||||
GGML_API int64_t ggml_nrows (const struct ggml_tensor * tensor);
|
GGML_API int64_t ggml_nrows (const struct ggml_tensor * tensor);
|
||||||
GGML_API size_t ggml_nbytes (const struct ggml_tensor * tensor);
|
GGML_API size_t ggml_nbytes (const struct ggml_tensor * tensor);
|
||||||
GGML_API size_t ggml_nbytes_pad (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN
|
GGML_API size_t ggml_nbytes_pad (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN
|
||||||
|
GGML_API size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split);
|
||||||
|
|
||||||
GGML_API int ggml_blck_size (enum ggml_type type);
|
GGML_API int ggml_blck_size (enum ggml_type type);
|
||||||
GGML_API size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
|
GGML_API size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
|
||||||
GGML_API size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row
|
GGML_API float ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float
|
||||||
|
|
||||||
GGML_DEPRECATED(
|
|
||||||
GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float
|
|
||||||
"use ggml_row_size() instead");
|
|
||||||
|
|
||||||
GGML_API const char * ggml_type_name(enum ggml_type type);
|
GGML_API const char * ggml_type_name(enum ggml_type type);
|
||||||
GGML_API const char * ggml_op_name (enum ggml_op op);
|
GGML_API const char * ggml_op_name (enum ggml_op op);
|
||||||
GGML_API const char * ggml_op_symbol(enum ggml_op op);
|
GGML_API const char * ggml_op_symbol(enum ggml_op op);
|
||||||
|
|
||||||
GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
|
|
||||||
GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
|
|
||||||
|
|
||||||
GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
|
GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
|
||||||
|
|
||||||
GGML_API bool ggml_is_quantized(enum ggml_type type);
|
GGML_API bool ggml_is_quantized(enum ggml_type type);
|
||||||
@ -673,11 +642,6 @@ extern "C" {
|
|||||||
GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor);
|
GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor);
|
||||||
GGML_API bool ggml_is_contiguous(const struct ggml_tensor * tensor);
|
GGML_API bool ggml_is_contiguous(const struct ggml_tensor * tensor);
|
||||||
GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor);
|
GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor);
|
||||||
GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor);
|
|
||||||
GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor);
|
|
||||||
GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor);
|
|
||||||
GGML_API bool ggml_is_3d (const struct ggml_tensor * tensor);
|
|
||||||
GGML_API int ggml_n_dims (const struct ggml_tensor * tensor); // returns 1 for scalars
|
|
||||||
|
|
||||||
GGML_API bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
|
GGML_API bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
|
||||||
|
|
||||||
@ -738,8 +702,8 @@ extern "C" {
|
|||||||
GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src);
|
GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src);
|
||||||
|
|
||||||
// Context tensor enumeration and lookup
|
// Context tensor enumeration and lookup
|
||||||
GGML_API struct ggml_tensor * ggml_get_first_tensor(const struct ggml_context * ctx);
|
GGML_API struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx);
|
||||||
GGML_API struct ggml_tensor * ggml_get_next_tensor (const struct ggml_context * ctx, struct ggml_tensor * tensor);
|
GGML_API struct ggml_tensor * ggml_get_next_tensor (struct ggml_context * ctx, struct ggml_tensor * tensor);
|
||||||
GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name);
|
GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name);
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
|
GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
|
||||||
@ -810,9 +774,6 @@ extern "C" {
|
|||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
struct ggml_tensor * b);
|
struct ggml_tensor * b);
|
||||||
|
|
||||||
// dst = a
|
|
||||||
// view(dst, nb1, nb2, nb3, offset) += b
|
|
||||||
// return dst
|
|
||||||
GGML_API struct ggml_tensor * ggml_acc(
|
GGML_API struct ggml_tensor * ggml_acc(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
@ -977,14 +938,15 @@ extern "C" {
|
|||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a);
|
struct ggml_tensor * a);
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_leaky_relu(
|
GGML_API struct ggml_tensor * ggml_leaky(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a, float negative_slope, bool inplace);
|
struct ggml_tensor * a);
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_relu_inplace(
|
GGML_API struct ggml_tensor * ggml_relu_inplace(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a);
|
struct ggml_tensor * a);
|
||||||
|
|
||||||
|
// TODO: double-check this computation is correct
|
||||||
GGML_API struct ggml_tensor * ggml_gelu(
|
GGML_API struct ggml_tensor * ggml_gelu(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a);
|
struct ggml_tensor * a);
|
||||||
@ -1066,22 +1028,6 @@ extern "C" {
|
|||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
struct ggml_tensor * b);
|
struct ggml_tensor * b);
|
||||||
|
|
||||||
// change the precision of a matrix multiplication
|
|
||||||
// set to GGML_PREC_F32 for higher precision (useful for phi-2)
|
|
||||||
GGML_API void ggml_mul_mat_set_prec(
|
|
||||||
struct ggml_tensor * a,
|
|
||||||
enum ggml_prec prec);
|
|
||||||
|
|
||||||
// indirect matrix multiplication
|
|
||||||
// ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
|
|
||||||
GGML_API struct ggml_tensor * ggml_mul_mat_id(
|
|
||||||
struct ggml_context * ctx,
|
|
||||||
struct ggml_tensor * const as[],
|
|
||||||
int n_as,
|
|
||||||
struct ggml_tensor * ids,
|
|
||||||
int id,
|
|
||||||
struct ggml_tensor * b);
|
|
||||||
|
|
||||||
// A: m columns, n rows,
|
// A: m columns, n rows,
|
||||||
// B: p columns, n rows,
|
// B: p columns, n rows,
|
||||||
// result is m columns, p rows
|
// result is m columns, p rows
|
||||||
@ -1097,13 +1043,13 @@ extern "C" {
|
|||||||
GGML_API struct ggml_tensor * ggml_scale(
|
GGML_API struct ggml_tensor * ggml_scale(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
float s);
|
struct ggml_tensor * b);
|
||||||
|
|
||||||
// in-place, returns view(a)
|
// in-place, returns view(a)
|
||||||
GGML_API struct ggml_tensor * ggml_scale_inplace(
|
GGML_API struct ggml_tensor * ggml_scale_inplace(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
float s);
|
struct ggml_tensor * b);
|
||||||
|
|
||||||
// b -> view(a,offset,nb1,nb2,3), return modified a
|
// b -> view(a,offset,nb1,nb2,3), return modified a
|
||||||
GGML_API struct ggml_tensor * ggml_set(
|
GGML_API struct ggml_tensor * ggml_set(
|
||||||
@ -1289,7 +1235,6 @@ extern "C" {
|
|||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a);
|
struct ggml_tensor * a);
|
||||||
|
|
||||||
// supports 3D: a->ne[2] == b->ne[1]
|
|
||||||
GGML_API struct ggml_tensor * ggml_get_rows(
|
GGML_API struct ggml_tensor * ggml_get_rows(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
@ -1338,14 +1283,6 @@ extern "C" {
|
|||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a);
|
struct ggml_tensor * a);
|
||||||
|
|
||||||
// fused soft_max(a*scale + mask)
|
|
||||||
// mask is optional
|
|
||||||
GGML_API struct ggml_tensor * ggml_soft_max_ext(
|
|
||||||
struct ggml_context * ctx,
|
|
||||||
struct ggml_tensor * a,
|
|
||||||
struct ggml_tensor * mask,
|
|
||||||
float scale);
|
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_soft_max_back(
|
GGML_API struct ggml_tensor * ggml_soft_max_back(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
@ -1576,32 +1513,6 @@ extern "C" {
|
|||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
int scale_factor);
|
int scale_factor);
|
||||||
|
|
||||||
// pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
|
|
||||||
GGML_API struct ggml_tensor * ggml_pad(
|
|
||||||
struct ggml_context * ctx,
|
|
||||||
struct ggml_tensor * a,
|
|
||||||
int p0,
|
|
||||||
int p1,
|
|
||||||
int p2,
|
|
||||||
int p3);
|
|
||||||
|
|
||||||
// sort rows
|
|
||||||
enum ggml_sort_order {
|
|
||||||
GGML_SORT_ASC,
|
|
||||||
GGML_SORT_DESC,
|
|
||||||
};
|
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_argsort(
|
|
||||||
struct ggml_context * ctx,
|
|
||||||
struct ggml_tensor * a,
|
|
||||||
enum ggml_sort_order order);
|
|
||||||
|
|
||||||
// top k elements per row
|
|
||||||
GGML_API struct ggml_tensor * ggml_top_k(
|
|
||||||
struct ggml_context * ctx,
|
|
||||||
struct ggml_tensor * a,
|
|
||||||
int k);
|
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_flash_attn(
|
GGML_API struct ggml_tensor * ggml_flash_attn(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * q,
|
struct ggml_tensor * q,
|
||||||
@ -1663,6 +1574,7 @@ extern "C" {
|
|||||||
int kh);
|
int kh);
|
||||||
|
|
||||||
// used in sam
|
// used in sam
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_add_rel_pos(
|
GGML_API struct ggml_tensor * ggml_add_rel_pos(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
@ -1837,7 +1749,7 @@ extern "C" {
|
|||||||
GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
|
GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
|
||||||
GGML_API struct ggml_cgraph * ggml_new_graph_custom (struct ggml_context * ctx, size_t size, bool grads);
|
GGML_API struct ggml_cgraph * ggml_new_graph_custom (struct ggml_context * ctx, size_t size, bool grads);
|
||||||
GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
|
GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
|
||||||
GGML_API struct ggml_cgraph ggml_graph_view (struct ggml_cgraph * cgraph, int i0, int i1);
|
GGML_API struct ggml_cgraph * ggml_graph_view (struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i0, int i1);
|
||||||
GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst);
|
GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst);
|
||||||
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // zero grads
|
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // zero grads
|
||||||
GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph);
|
GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph);
|
||||||
@ -2133,7 +2045,6 @@ extern "C" {
|
|||||||
GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
|
GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
|
||||||
GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
|
GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
|
||||||
GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
|
GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
|
||||||
GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id);
|
|
||||||
GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id);
|
GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id);
|
||||||
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
|
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
|
||||||
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
|
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
|
||||||
@ -2142,7 +2053,6 @@ extern "C" {
|
|||||||
GGML_API int gguf_find_tensor (const struct gguf_context * ctx, const char * name);
|
GGML_API int gguf_find_tensor (const struct gguf_context * ctx, const char * name);
|
||||||
GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i);
|
GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i);
|
||||||
GGML_API char * gguf_get_tensor_name (const struct gguf_context * ctx, int i);
|
GGML_API char * gguf_get_tensor_name (const struct gguf_context * ctx, int i);
|
||||||
GGML_API enum ggml_type gguf_get_tensor_type (const struct gguf_context * ctx, int i);
|
|
||||||
|
|
||||||
// overrides existing values or adds a new one
|
// overrides existing values or adds a new one
|
||||||
GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val);
|
GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val);
|
||||||
@ -2198,7 +2108,6 @@ extern "C" {
|
|||||||
//
|
//
|
||||||
|
|
||||||
GGML_API int ggml_cpu_has_avx (void);
|
GGML_API int ggml_cpu_has_avx (void);
|
||||||
GGML_API int ggml_cpu_has_avx_vnni (void);
|
|
||||||
GGML_API int ggml_cpu_has_avx2 (void);
|
GGML_API int ggml_cpu_has_avx2 (void);
|
||||||
GGML_API int ggml_cpu_has_avx512 (void);
|
GGML_API int ggml_cpu_has_avx512 (void);
|
||||||
GGML_API int ggml_cpu_has_avx512_vbmi(void);
|
GGML_API int ggml_cpu_has_avx512_vbmi(void);
|
||||||
|
@ -19,7 +19,7 @@ function get_script_path() {
|
|||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
models_path="${2:-$(get_script_path)}"
|
models_path="$(get_script_path)"
|
||||||
|
|
||||||
# Whisper models
|
# Whisper models
|
||||||
models=(
|
models=(
|
||||||
@ -43,7 +43,7 @@ models=(
|
|||||||
"large-v1"
|
"large-v1"
|
||||||
"large-v2"
|
"large-v2"
|
||||||
"large-v3"
|
"large-v3"
|
||||||
"large-v3-q5_0"
|
"large-q5_0"
|
||||||
)
|
)
|
||||||
|
|
||||||
# list available models
|
# list available models
|
||||||
@ -56,8 +56,8 @@ function list_models {
|
|||||||
printf "\n\n"
|
printf "\n\n"
|
||||||
}
|
}
|
||||||
|
|
||||||
if [ "$#" -lt 1 ] || [ "$#" -gt 2 ]; then
|
if [ "$#" -ne 1 ]; then
|
||||||
printf "Usage: $0 <model> [models_path]\n"
|
printf "Usage: $0 <model>\n"
|
||||||
list_models
|
list_models
|
||||||
|
|
||||||
exit 1
|
exit 1
|
||||||
@ -105,7 +105,7 @@ if [ $? -ne 0 ]; then
|
|||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
printf "Done! Model '$model' saved in '$models_path/ggml-$model.bin'\n"
|
printf "Done! Model '$model' saved in 'models/ggml-$model.bin'\n"
|
||||||
printf "You can now use it like this:\n\n"
|
printf "You can now use it like this:\n\n"
|
||||||
printf " $ ./main -m $models_path/ggml-$model.bin -f samples/jfk.wav\n"
|
printf " $ ./main -m models/ggml-$model.bin -f samples/jfk.wav\n"
|
||||||
printf "\n"
|
printf "\n"
|
||||||
|
@ -64,15 +64,15 @@ int whisper_openvino_encode(
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ggml_n_dims(mel) != 2) {
|
if (mel->n_dims != 2) {
|
||||||
fprintf(stderr, "%s: Error! mel ggml_tensor expected to have n_dims=2, but it has n_dims=%d\n",
|
fprintf(stderr, "%s: Error! mel ggml_tensor expected to have n_dims=2, but it has n_dims=%d\n",
|
||||||
__func__, ggml_n_dims(mel));
|
__func__, mel->n_dims);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ggml_n_dims(out) != 2) {
|
if (out->n_dims != 2) {
|
||||||
fprintf(stderr, "%s: Error! out ggml_tensor expected to have n_dims=2, but it has n_dims=%d\n",
|
fprintf(stderr, "%s: Error! out ggml_tensor expected to have n_dims=2, but it has n_dims=%d\n",
|
||||||
__func__, ggml_n_dims(out));
|
__func__, out->n_dims);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
199
whisper.cpp
@ -122,18 +122,9 @@ WHISPER_ATTRIBUTE_FORMAT(2, 3)
|
|||||||
static void whisper_log_internal (ggml_log_level level, const char * format, ...);
|
static void whisper_log_internal (ggml_log_level level, const char * format, ...);
|
||||||
static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data);
|
static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data);
|
||||||
|
|
||||||
#define WHISPER_LOG_ERROR(...) whisper_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
|
||||||
#define WHISPER_LOG_WARN(...) whisper_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
|
|
||||||
#define WHISPER_LOG_INFO(...) whisper_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
|
#define WHISPER_LOG_INFO(...) whisper_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
|
||||||
|
#define WHISPER_LOG_WARN(...) whisper_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
|
||||||
// define this to enable verbose trace logging - useful for debugging purposes
|
#define WHISPER_LOG_ERROR(...) whisper_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
||||||
// #define WHISPER_DEBUG
|
|
||||||
|
|
||||||
#if defined(WHISPER_DEBUG)
|
|
||||||
#define WHISPER_LOG_DEBUG(...) whisper_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
|
|
||||||
#else
|
|
||||||
#define WHISPER_LOG_DEBUG(...)
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#define WHISPER_ASSERT(x) \
|
#define WHISPER_ASSERT(x) \
|
||||||
do { \
|
do { \
|
||||||
@ -143,6 +134,18 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
|
|||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
|
// define this to enable verbose trace logging - useful for debugging purposes
|
||||||
|
//#define WHISPER_DEBUG
|
||||||
|
|
||||||
|
#if defined(WHISPER_DEBUG)
|
||||||
|
#define WHISPER_PRINT_DEBUG(...) \
|
||||||
|
do { \
|
||||||
|
fprintf(stderr, __VA_ARGS__); \
|
||||||
|
} while (0)
|
||||||
|
#else
|
||||||
|
#define WHISPER_PRINT_DEBUG(...)
|
||||||
|
#endif
|
||||||
|
|
||||||
//#define WHISPER_USE_FLASH_ATTN
|
//#define WHISPER_USE_FLASH_ATTN
|
||||||
//#define WHISPER_USE_FLASH_FF
|
//#define WHISPER_USE_FLASH_FF
|
||||||
#define WHISPER_MAX_DECODERS 8
|
#define WHISPER_MAX_DECODERS 8
|
||||||
@ -152,7 +155,7 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
|
|||||||
// ggml helpers
|
// ggml helpers
|
||||||
//
|
//
|
||||||
|
|
||||||
static bool ggml_graph_compute_helper(
|
static void ggml_graph_compute_helper(
|
||||||
struct ggml_cgraph * graph,
|
struct ggml_cgraph * graph,
|
||||||
std::vector<uint8_t> & buf,
|
std::vector<uint8_t> & buf,
|
||||||
int n_threads,
|
int n_threads,
|
||||||
@ -168,10 +171,10 @@ static bool ggml_graph_compute_helper(
|
|||||||
plan.work_data = buf.data();
|
plan.work_data = buf.data();
|
||||||
}
|
}
|
||||||
|
|
||||||
return ggml_graph_compute(graph, &plan);
|
ggml_graph_compute(graph, &plan);
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool ggml_graph_compute_helper(
|
static void ggml_graph_compute_helper(
|
||||||
struct ggml_backend * backend,
|
struct ggml_backend * backend,
|
||||||
struct ggml_cgraph * graph,
|
struct ggml_cgraph * graph,
|
||||||
int n_threads) {
|
int n_threads) {
|
||||||
@ -183,7 +186,7 @@ static bool ggml_graph_compute_helper(
|
|||||||
ggml_backend_metal_set_n_cb(backend, n_threads);
|
ggml_backend_metal_set_n_cb(backend, n_threads);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
return ggml_backend_graph_compute(backend, graph);
|
ggml_backend_graph_compute(backend, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
// faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
|
// faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
|
||||||
@ -1060,7 +1063,7 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
|
|||||||
#ifdef GGML_USE_CUBLAS
|
#ifdef GGML_USE_CUBLAS
|
||||||
if (params.use_gpu && ggml_cublas_loaded()) {
|
if (params.use_gpu && ggml_cublas_loaded()) {
|
||||||
WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
|
WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
|
||||||
backend_gpu = ggml_backend_cuda_init(0);
|
backend_gpu = ggml_backend_cuda_init();
|
||||||
if (!backend_gpu) {
|
if (!backend_gpu) {
|
||||||
WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
|
WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
|
||||||
}
|
}
|
||||||
@ -1074,10 +1077,6 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
|
|||||||
backend_gpu = ggml_backend_metal_init();
|
backend_gpu = ggml_backend_metal_init();
|
||||||
if (!backend_gpu) {
|
if (!backend_gpu) {
|
||||||
WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
|
WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
|
||||||
} else if (!ggml_backend_metal_supports_family(backend_gpu, 7)) {
|
|
||||||
WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__);
|
|
||||||
ggml_backend_free(backend_gpu);
|
|
||||||
backend_gpu = NULL;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
@ -1342,10 +1341,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|||||||
model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
|
model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
|
||||||
|
|
||||||
model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
|
model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
|
||||||
model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
|
model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2*n_audio_ctx, n_audio_state);
|
||||||
|
|
||||||
model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
|
model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
|
||||||
model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
|
model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_ctx, n_audio_state);
|
||||||
|
|
||||||
model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
||||||
model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
|
||||||
@ -1575,6 +1574,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|||||||
|
|
||||||
auto tensor = model.tensors[name.data()];
|
auto tensor = model.tensors[name.data()];
|
||||||
|
|
||||||
|
const bool is_conv_bias = (name == "encoder.conv1.bias" || name == "encoder.conv2.bias");
|
||||||
|
|
||||||
|
if (!is_conv_bias) {
|
||||||
if (ggml_nelements(tensor) != nelements) {
|
if (ggml_nelements(tensor) != nelements) {
|
||||||
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
|
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
|
||||||
WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
|
WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
|
||||||
@ -1595,6 +1597,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|||||||
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
|
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ggml_backend_t backend = wctx.backend;
|
ggml_backend_t backend = wctx.backend;
|
||||||
|
|
||||||
@ -1604,7 +1607,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|||||||
#ifdef GGML_USE_METAL
|
#ifdef GGML_USE_METAL
|
||||||
|| ggml_backend_is_metal(backend)
|
|| ggml_backend_is_metal(backend)
|
||||||
#endif
|
#endif
|
||||||
)) {
|
) && !is_conv_bias) {
|
||||||
// for the CPU and Metal backend, we can read directly into the tensor
|
// for the CPU and Metal backend, we can read directly into the tensor
|
||||||
loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
|
loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
|
||||||
BYTESWAP_TENSOR(tensor);
|
BYTESWAP_TENSOR(tensor);
|
||||||
@ -1612,7 +1615,24 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|||||||
// read into a temporary buffer first, then copy to device memory
|
// read into a temporary buffer first, then copy to device memory
|
||||||
read_buf.resize(ggml_nbytes(tensor));
|
read_buf.resize(ggml_nbytes(tensor));
|
||||||
|
|
||||||
|
// we repeat the 2 bias tensors along dim 0:
|
||||||
|
// [1, 512] -> [3000, 512] (conv1.bias)
|
||||||
|
// [1, 512] -> [1500, 512] (conv2.bias)
|
||||||
|
if (is_conv_bias) {
|
||||||
|
loader->read(loader->context, read_buf.data(), read_buf.size() / tensor->ne[0]);
|
||||||
|
|
||||||
|
float * data_f32 = (float *) read_buf.data();
|
||||||
|
for (int64_t y = 0; y < tensor->ne[1]; ++y) {
|
||||||
|
const int64_t yy = tensor->ne[1] - y - 1;
|
||||||
|
const float val = data_f32[yy];
|
||||||
|
|
||||||
|
for (int64_t x = 0; x < tensor->ne[0]; ++x) {
|
||||||
|
data_f32[yy*tensor->ne[0] + x] = val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
loader->read(loader->context, read_buf.data(), read_buf.size());
|
loader->read(loader->context, read_buf.data(), read_buf.size());
|
||||||
|
}
|
||||||
|
|
||||||
ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
|
ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
|
||||||
}
|
}
|
||||||
@ -1712,12 +1732,20 @@ static struct ggml_cgraph * whisper_build_graph_conv(
|
|||||||
// convolution + gelu
|
// convolution + gelu
|
||||||
{
|
{
|
||||||
cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
|
cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
|
||||||
|
if (n_ctx == hparams.n_audio_ctx) {
|
||||||
cur = ggml_add(ctx0, cur, model.e_conv_1_b);
|
cur = ggml_add(ctx0, cur, model.e_conv_1_b);
|
||||||
|
} else {
|
||||||
|
cur = ggml_add(ctx0, cur, ggml_cont(ctx0, ggml_view_2d(ctx0, model.e_conv_1_b, cur->ne[0], cur->ne[1], model.e_conv_1_b->nb[1], 0)));
|
||||||
|
}
|
||||||
|
|
||||||
cur = ggml_gelu(ctx0, cur);
|
cur = ggml_gelu(ctx0, cur);
|
||||||
|
|
||||||
cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
|
cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
|
||||||
|
if (n_ctx == hparams.n_audio_ctx) {
|
||||||
cur = ggml_add(ctx0, cur, model.e_conv_2_b);
|
cur = ggml_add(ctx0, cur, model.e_conv_2_b);
|
||||||
|
} else {
|
||||||
|
cur = ggml_add(ctx0, cur, ggml_cont(ctx0, ggml_view_2d(ctx0, model.e_conv_2_b, cur->ne[0], cur->ne[1], model.e_conv_2_b->nb[1], 0)));
|
||||||
|
}
|
||||||
|
|
||||||
cur = ggml_gelu(ctx0, cur);
|
cur = ggml_gelu(ctx0, cur);
|
||||||
}
|
}
|
||||||
@ -1774,7 +1802,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|||||||
|
|
||||||
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
|
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
|
||||||
|
|
||||||
//ggml_allocr * alloc = wstate.alloc_encode.alloc;
|
ggml_allocr * alloc = wstate.alloc_encode.alloc;
|
||||||
|
|
||||||
//struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_ctx, n_state);
|
//struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_ctx, n_state);
|
||||||
//ggml_allocr_alloc(alloc, cur);
|
//ggml_allocr_alloc(alloc, cur);
|
||||||
@ -1784,7 +1812,13 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|||||||
//}
|
//}
|
||||||
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
|
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
|
||||||
|
|
||||||
const float KQscale = 1.0f/sqrtf(float(n_state)/n_head);
|
struct ggml_tensor * KQscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
|
||||||
|
ggml_allocr_alloc(alloc, KQscale);
|
||||||
|
|
||||||
|
if (!ggml_allocr_is_measure(alloc)) {
|
||||||
|
const float val = 1.0f/sqrtf(float(n_state)/n_head);
|
||||||
|
ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
|
||||||
|
}
|
||||||
|
|
||||||
// ===================================================================
|
// ===================================================================
|
||||||
// NOTE: experimenting with partial evaluation of the encoder (ignore)
|
// NOTE: experimenting with partial evaluation of the encoder (ignore)
|
||||||
@ -1834,14 +1868,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|||||||
|
|
||||||
Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b);
|
Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b);
|
||||||
|
|
||||||
//Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state)/n_head, -0.25));
|
//Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
||||||
|
|
||||||
// note: no bias for Key
|
// note: no bias for Key
|
||||||
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
|
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
|
||||||
layer.attn_k_w,
|
layer.attn_k_w,
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
//Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state)/n_head, -0.25));
|
//Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
|
||||||
|
|
||||||
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
|
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
|
||||||
layer.attn_v_w,
|
layer.attn_v_w,
|
||||||
@ -2023,7 +2057,7 @@ static struct ggml_cgraph * whisper_build_graph_cross(
|
|||||||
|
|
||||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||||
|
|
||||||
//ggml_allocr * alloc = wstate.alloc_cross.alloc;
|
ggml_allocr * alloc = wstate.alloc_cross.alloc;
|
||||||
|
|
||||||
//struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
|
//struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
|
||||||
//ggml_allocr_alloc(alloc, cur);
|
//ggml_allocr_alloc(alloc, cur);
|
||||||
@ -2033,7 +2067,13 @@ static struct ggml_cgraph * whisper_build_graph_cross(
|
|||||||
//}
|
//}
|
||||||
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
|
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
|
||||||
|
|
||||||
const float Kscale = pow(float(n_state) / n_head, -0.25);
|
struct ggml_tensor * Kscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
|
||||||
|
ggml_allocr_alloc(alloc, Kscale);
|
||||||
|
|
||||||
|
if (!ggml_allocr_is_measure(alloc)) {
|
||||||
|
const float val = pow(float(n_state) / n_head, -0.25);
|
||||||
|
ggml_backend_tensor_set(Kscale, &val, 0, sizeof(float));
|
||||||
|
}
|
||||||
|
|
||||||
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
|
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
|
||||||
auto & layer = model.layers_decoder[il];
|
auto & layer = model.layers_decoder[il];
|
||||||
@ -2103,9 +2143,7 @@ static bool whisper_encode_internal(
|
|||||||
ggml_allocr_alloc_graph(alloc, gf);
|
ggml_allocr_alloc_graph(alloc, gf);
|
||||||
|
|
||||||
if (!whisper_encode_external(wstate)) {
|
if (!whisper_encode_external(wstate)) {
|
||||||
if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
|
ggml_graph_compute_helper(wstate.backend, gf, n_threads);
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2119,9 +2157,7 @@ static bool whisper_encode_internal(
|
|||||||
|
|
||||||
ggml_allocr_alloc_graph(alloc, gf);
|
ggml_allocr_alloc_graph(alloc, gf);
|
||||||
|
|
||||||
if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
|
ggml_graph_compute_helper(wstate.backend, gf, n_threads);
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// cross
|
// cross
|
||||||
@ -2134,9 +2170,7 @@ static bool whisper_encode_internal(
|
|||||||
|
|
||||||
ggml_allocr_alloc_graph(alloc, gf);
|
ggml_allocr_alloc_graph(alloc, gf);
|
||||||
|
|
||||||
if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
|
ggml_graph_compute_helper(wstate.backend, gf, n_threads);
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
wstate.t_encode_us += ggml_time_us() - t_start_us;
|
wstate.t_encode_us += ggml_time_us() - t_start_us;
|
||||||
@ -2169,7 +2203,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||||||
const int32_t n_kv = ggml_allocr_is_measure(alloc) ? n_ctx : kv_self.n;
|
const int32_t n_kv = ggml_allocr_is_measure(alloc) ? n_ctx : kv_self.n;
|
||||||
const int32_t kv_head = ggml_allocr_is_measure(alloc) ? n_ctx - n_tokens : kv_self.head;
|
const int32_t kv_head = ggml_allocr_is_measure(alloc) ? n_ctx - n_tokens : kv_self.head;
|
||||||
|
|
||||||
//WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
|
//WHISPER_PRINT_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
|
||||||
|
|
||||||
struct ggml_init_params params = {
|
struct ggml_init_params params = {
|
||||||
/*.mem_size =*/ wstate.alloc_decode.meta.size(),
|
/*.mem_size =*/ wstate.alloc_decode.meta.size(),
|
||||||
@ -2198,7 +2232,13 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const float KQscale = pow(float(n_state)/n_head, -0.25);
|
struct ggml_tensor * KQscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
|
||||||
|
ggml_allocr_alloc(alloc, KQscale);
|
||||||
|
|
||||||
|
if (!ggml_allocr_is_measure(alloc)) {
|
||||||
|
const float val = pow(float(n_state)/n_head, -0.25);
|
||||||
|
ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
|
||||||
|
}
|
||||||
|
|
||||||
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
|
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
|
||||||
ggml_allocr_alloc(alloc, KQ_mask);
|
ggml_allocr_alloc(alloc, KQ_mask);
|
||||||
@ -2558,9 +2598,7 @@ static bool whisper_decode_internal(
|
|||||||
|
|
||||||
logits = gf->nodes[gf->n_nodes - 1];
|
logits = gf->nodes[gf->n_nodes - 1];
|
||||||
|
|
||||||
if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
|
ggml_graph_compute_helper(wstate.backend, gf, n_threads);
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
logits_out.resize(n_tokens*n_vocab);
|
logits_out.resize(n_tokens*n_vocab);
|
||||||
@ -3555,17 +3593,6 @@ const char * whisper_lang_str(int id) {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
const char * whisper_lang_str_full(int id) {
|
|
||||||
for (const auto & kv : g_lang) {
|
|
||||||
if (kv.second.first == id) {
|
|
||||||
return kv.second.second.c_str();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
WHISPER_LOG_ERROR("%s: unknown language id %d\n", __func__, id);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
int whisper_lang_auto_detect_with_state(
|
int whisper_lang_auto_detect_with_state(
|
||||||
struct whisper_context * ctx,
|
struct whisper_context * ctx,
|
||||||
struct whisper_state * state,
|
struct whisper_state * state,
|
||||||
@ -4949,7 +4976,7 @@ static void whisper_sequence_score(
|
|||||||
const auto p = kv.second/(double)cnt;
|
const auto p = kv.second/(double)cnt;
|
||||||
entropy -= p*log(p);
|
entropy -= p*log(p);
|
||||||
|
|
||||||
//WHISPER_LOG_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second);
|
//WHISPER_PRINT_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second);
|
||||||
}
|
}
|
||||||
|
|
||||||
sequence.entropy = entropy;
|
sequence.entropy = entropy;
|
||||||
@ -5015,7 +5042,6 @@ int whisper_full_with_state(
|
|||||||
// basically don't process anything that is less than 1.0s
|
// basically don't process anything that is less than 1.0s
|
||||||
// see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
|
// see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
|
||||||
if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
|
if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
|
||||||
WHISPER_LOG_DEBUG("%s: input is too short - %d ms < 1000 ms\n", __func__, (seek_end - seek_start)*10);
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -5154,7 +5180,7 @@ int whisper_full_with_state(
|
|||||||
ctx, state, progress_cur, params.progress_callback_user_data);
|
ctx, state, progress_cur, params.progress_callback_user_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
// if only 1 second left, then stop
|
// of only 1 second left, then stop
|
||||||
if (seek + 100 >= seek_end) {
|
if (seek + 100 >= seek_end) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -5204,7 +5230,7 @@ int whisper_full_with_state(
|
|||||||
|
|
||||||
n_decoders_cur = std::max(1, n_decoders_cur);
|
n_decoders_cur = std::max(1, n_decoders_cur);
|
||||||
|
|
||||||
WHISPER_LOG_DEBUG("\n%s: strategy = %d, decoding with %d decoders, temperature = %.2f\n", __func__, params.strategy, n_decoders_cur, t_cur);
|
WHISPER_PRINT_DEBUG("\n%s: strategy = %d, decoding with %d decoders, temperature = %.2f\n", __func__, params.strategy, n_decoders_cur, t_cur);
|
||||||
|
|
||||||
// TAGS: WHISPER_DECODER_INIT
|
// TAGS: WHISPER_DECODER_INIT
|
||||||
for (int j = 0; j < n_decoders_cur; ++j) {
|
for (int j = 0; j < n_decoders_cur; ++j) {
|
||||||
@ -5248,11 +5274,11 @@ int whisper_full_with_state(
|
|||||||
prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
|
prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
|
||||||
|
|
||||||
// print the prompt
|
// print the prompt
|
||||||
WHISPER_LOG_DEBUG("\n\n");
|
WHISPER_PRINT_DEBUG("\n\n");
|
||||||
for (int i = 0; i < (int) prompt.size(); i++) {
|
for (int i = 0; i < (int) prompt.size(); i++) {
|
||||||
WHISPER_LOG_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
|
WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
|
||||||
}
|
}
|
||||||
WHISPER_LOG_DEBUG("\n\n");
|
WHISPER_PRINT_DEBUG("\n\n");
|
||||||
|
|
||||||
whisper_kv_cache_clear(state->kv_self);
|
whisper_kv_cache_clear(state->kv_self);
|
||||||
|
|
||||||
@ -5400,7 +5426,7 @@ int whisper_full_with_state(
|
|||||||
|
|
||||||
whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1);
|
whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1);
|
||||||
|
|
||||||
WHISPER_LOG_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
|
WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
|
||||||
__func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
|
__func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -5443,7 +5469,6 @@ int whisper_full_with_state(
|
|||||||
|
|
||||||
// do not allow to go back in time
|
// do not allow to go back in time
|
||||||
if (has_ts && seek_delta > seek_delta_new && result_len < i) {
|
if (has_ts && seek_delta > seek_delta_new && result_len < i) {
|
||||||
WHISPER_LOG_DEBUG("%s: decoder %d: failed due to seek_delta (%d > %d)\n", __func__, j, seek_delta, seek_delta_new);
|
|
||||||
failed = true; // TODO: maybe this is not a failure ?
|
failed = true; // TODO: maybe this is not a failure ?
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -5458,7 +5483,7 @@ int whisper_full_with_state(
|
|||||||
#ifdef WHISPER_DEBUG
|
#ifdef WHISPER_DEBUG
|
||||||
{
|
{
|
||||||
const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
|
const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
|
||||||
WHISPER_LOG_DEBUG("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n",
|
WHISPER_PRINT_DEBUG("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n",
|
||||||
__func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str());
|
__func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str());
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
@ -5472,7 +5497,6 @@ int whisper_full_with_state(
|
|||||||
if (seek + seek_delta + 100 >= seek_end) {
|
if (seek + seek_delta + 100 >= seek_end) {
|
||||||
result_len = i + 1;
|
result_len = i + 1;
|
||||||
} else {
|
} else {
|
||||||
WHISPER_LOG_DEBUG("%s: decoder %d failed (result_len = 0)\n", __func__, j);
|
|
||||||
failed = true;
|
failed = true;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -5483,7 +5507,6 @@ int whisper_full_with_state(
|
|||||||
seek_delta = 100*WHISPER_CHUNK_SIZE;
|
seek_delta = 100*WHISPER_CHUNK_SIZE;
|
||||||
}
|
}
|
||||||
|
|
||||||
WHISPER_LOG_DEBUG("%s: decoder %d completed\n", __func__, j);
|
|
||||||
completed = true;
|
completed = true;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -5499,7 +5522,6 @@ int whisper_full_with_state(
|
|||||||
// sometimes, the decoding can get stuck in a repetition loop
|
// sometimes, the decoding can get stuck in a repetition loop
|
||||||
// this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
|
// this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
|
||||||
if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
|
if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
|
||||||
WHISPER_LOG_DEBUG("%s: decoder %d: failed due to repetition loop\n", __func__, j);
|
|
||||||
failed = true;
|
failed = true;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -5541,7 +5563,7 @@ int whisper_full_with_state(
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
//WHISPER_LOG_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
|
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
|
||||||
|
|
||||||
decoder.i_batch = batch.n_tokens;
|
decoder.i_batch = batch.n_tokens;
|
||||||
|
|
||||||
@ -5621,11 +5643,11 @@ int whisper_full_with_state(
|
|||||||
decoder.sequence.tokens.resize(decoder.sequence.result_len);
|
decoder.sequence.tokens.resize(decoder.sequence.result_len);
|
||||||
whisper_sequence_score(params, decoder.sequence);
|
whisper_sequence_score(params, decoder.sequence);
|
||||||
|
|
||||||
WHISPER_LOG_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n",
|
WHISPER_PRINT_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n",
|
||||||
__func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy);
|
__func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy);
|
||||||
|
|
||||||
if (decoder.sequence.result_len > 32 && decoder.sequence.entropy < params.entropy_thold) {
|
if (decoder.sequence.result_len > 32 && decoder.sequence.entropy < params.entropy_thold) {
|
||||||
WHISPER_LOG_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n",
|
WHISPER_PRINT_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n",
|
||||||
__func__, j, decoder.sequence.entropy, params.entropy_thold);
|
__func__, j, decoder.sequence.entropy, params.entropy_thold);
|
||||||
|
|
||||||
decoder.failed = true;
|
decoder.failed = true;
|
||||||
@ -5640,33 +5662,34 @@ int whisper_full_with_state(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
WHISPER_LOG_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
|
WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool success = true;
|
|
||||||
|
|
||||||
// was the decoding successful for the current temperature?
|
// was the decoding successful for the current temperature?
|
||||||
// do fallback only if:
|
// do fallback only if:
|
||||||
// - we are not at the last temperature
|
// - we are not at the last temperature
|
||||||
if (it != (int) temperatures.size() - 1) {
|
// - we are not at the end of the audio (3 sec)
|
||||||
|
if (it != (int) temperatures.size() - 1 &&
|
||||||
|
seek_end - seek > 10*WHISPER_CHUNK_SIZE) {
|
||||||
|
bool success = true;
|
||||||
|
|
||||||
const auto & decoder = state->decoders[best_decoder_id];
|
const auto & decoder = state->decoders[best_decoder_id];
|
||||||
|
|
||||||
if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
|
if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
|
||||||
WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold);
|
|
||||||
success = false;
|
success = false;
|
||||||
state->n_fail_p++;
|
state->n_fail_p++;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if (success) {
|
if (success) {
|
||||||
//for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
|
//for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
|
||||||
// WHISPER_LOG_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
|
// WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
|
||||||
//}
|
//}
|
||||||
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
WHISPER_LOG_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
|
WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
|
||||||
}
|
}
|
||||||
|
|
||||||
// output results through a user-provided callback
|
// output results through a user-provided callback
|
||||||
@ -5678,7 +5701,7 @@ int whisper_full_with_state(
|
|||||||
|
|
||||||
const auto & tokens_cur = best_decoder.sequence.tokens;
|
const auto & tokens_cur = best_decoder.sequence.tokens;
|
||||||
|
|
||||||
//WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
|
//WHISPER_PRINT_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
|
||||||
|
|
||||||
// update prompt_past
|
// update prompt_past
|
||||||
prompt_past.clear();
|
prompt_past.clear();
|
||||||
@ -5798,7 +5821,7 @@ int whisper_full_with_state(
|
|||||||
// update audio window
|
// update audio window
|
||||||
seek += seek_delta;
|
seek += seek_delta;
|
||||||
|
|
||||||
WHISPER_LOG_DEBUG("seek = %d, seek_delta = %d\n", seek, seek_delta);
|
WHISPER_PRINT_DEBUG("seek = %d, seek_delta = %d\n", seek, seek_delta);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -6115,7 +6138,7 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
|
|||||||
|
|
||||||
// multi-thread
|
// multi-thread
|
||||||
|
|
||||||
for (int32_t k = 1; k <= n_threads; k++) {
|
for (uint32_t n_threads = 1; n_threads <= std::thread::hardware_concurrency(); n_threads++) {
|
||||||
char * src = (char *) malloc(size);
|
char * src = (char *) malloc(size);
|
||||||
char * dst = (char *) malloc(size);
|
char * dst = (char *) malloc(size);
|
||||||
|
|
||||||
@ -6126,8 +6149,8 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
|
|||||||
double tsum = 0.0;
|
double tsum = 0.0;
|
||||||
|
|
||||||
auto helper = [&](int th) {
|
auto helper = [&](int th) {
|
||||||
const int64_t i0 = (th + 0)*size/k;
|
const int64_t i0 = (th + 0)*size/n_threads;
|
||||||
const int64_t i1 = (th + 1)*size/k;
|
const int64_t i1 = (th + 1)*size/n_threads;
|
||||||
|
|
||||||
for (size_t i = 0; i < n; i++) {
|
for (size_t i = 0; i < n; i++) {
|
||||||
memcpy(dst + i0, src + i0, i1 - i0);
|
memcpy(dst + i0, src + i0, i1 - i0);
|
||||||
@ -6138,14 +6161,14 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
|
|||||||
|
|
||||||
const int64_t t0 = ggml_time_us();
|
const int64_t t0 = ggml_time_us();
|
||||||
|
|
||||||
std::vector<std::thread> threads(k - 1);
|
std::vector<std::thread> threads(n_threads - 1);
|
||||||
for (int32_t th = 0; th < k - 1; ++th) {
|
for (uint32_t th = 0; th < n_threads - 1; ++th) {
|
||||||
threads[th] = std::thread(helper, th);
|
threads[th] = std::thread(helper, th);
|
||||||
}
|
}
|
||||||
|
|
||||||
helper(k - 1);
|
helper(n_threads - 1);
|
||||||
|
|
||||||
for (int32_t th = 0; th < k - 1; ++th) {
|
for (uint32_t th = 0; th < n_threads - 1; ++th) {
|
||||||
threads[th].join();
|
threads[th].join();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -6153,7 +6176,7 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
|
|||||||
|
|
||||||
tsum += (t1 - t0)*1e-6;
|
tsum += (t1 - t0)*1e-6;
|
||||||
|
|
||||||
snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (%2d thread)\n", (double) (n*size)/(tsum*1e9), k);
|
snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (%2d thread)\n", (double) (n*size)/(tsum*1e9), n_threads);
|
||||||
s += strbuf;
|
s += strbuf;
|
||||||
|
|
||||||
// needed to prevent the compiler from optimizing the memcpy away
|
// needed to prevent the compiler from optimizing the memcpy away
|
||||||
|
@ -315,9 +315,6 @@ extern "C" {
|
|||||||
// Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found
|
// Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found
|
||||||
WHISPER_API const char * whisper_lang_str(int id);
|
WHISPER_API const char * whisper_lang_str(int id);
|
||||||
|
|
||||||
// Return the short string of the specified language name (e.g. 2 -> "german"), returns nullptr if not found
|
|
||||||
WHISPER_API const char * whisper_lang_str_full(int id);
|
|
||||||
|
|
||||||
// Use mel data at offset_ms to try and auto-detect the spoken language
|
// Use mel data at offset_ms to try and auto-detect the spoken language
|
||||||
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
|
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
|
||||||
// Returns the top language id or negative on failure
|
// Returns the top language id or negative on failure
|
||||||
|