mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-07-06 01:11:40 +02:00
Compare commits
49 Commits
grammar-de
...
distil-sup
Author | SHA1 | Date | |
---|---|---|---|
673c55c683 | |||
b8c93c5f3b | |||
8a2bee6717 | |||
d445098c8f | |||
74de25158e | |||
bce49a260e | |||
45c87b5481 | |||
dfe4bc6e59 | |||
54c978c3a3 | |||
9a7074d4aa | |||
a0040f5d12 | |||
940cdb1396 | |||
1b775cdd68 | |||
80bf931668 | |||
91c0b23384 | |||
2f668c330e | |||
08fa34882f | |||
4037705531 | |||
c76c11e59c | |||
9edbd0a204 | |||
707507ff6d | |||
7e1592d2cd | |||
903c9579b8 | |||
b440ef8c96 | |||
700f63a806 | |||
951a119926 | |||
1ca4041b86 | |||
80c1512fd5 | |||
0ac9cefd03 | |||
b8432f28f4 | |||
93935980f8 | |||
3fec2119e6 | |||
9b14418863 | |||
6ddc727fac | |||
acb5278cc8 | |||
0839209cab | |||
b39809668a | |||
3e9edc6845 | |||
bfc73f1fa2 | |||
f00c9bba33 | |||
b55b505690 | |||
2818de21ff | |||
aed5d40607 | |||
afa5477d1c | |||
01fcd42431 | |||
f990610776 | |||
64cb45fd79 | |||
ace6c12ec6 | |||
cac75be05b |
28
.devops/cublas.Dockerfile
Normal file
28
.devops/cublas.Dockerfile
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
ARG UBUNTU_VERSION=22.04
|
||||||
|
|
||||||
|
# This needs to generally match the container host's environment.
|
||||||
|
ARG CUDA_VERSION=11.7.1
|
||||||
|
|
||||||
|
# Target the CUDA build image
|
||||||
|
ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
|
||||||
|
|
||||||
|
FROM ${BASE_CUDA_DEV_CONTAINER} as build
|
||||||
|
|
||||||
|
# Unless otherwise specified, we make a fat build.
|
||||||
|
ARG CUDA_DOCKER_ARCH=all
|
||||||
|
|
||||||
|
RUN apt-get update && \
|
||||||
|
apt-get install -y build-essential git cmake
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
# Set nvcc architecture
|
||||||
|
ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
|
||||||
|
# Enable cuBLAS
|
||||||
|
ENV WHISPER_CUBLAS=1
|
||||||
|
|
||||||
|
RUN make
|
||||||
|
|
||||||
|
ENTRYPOINT ["/app/main"]
|
6
.github/workflows/build.yml
vendored
6
.github/workflows/build.yml
vendored
@ -428,15 +428,15 @@ jobs:
|
|||||||
|
|
||||||
- name: Publish package
|
- name: Publish package
|
||||||
if: ${{ github.ref == 'refs/heads/master' }}
|
if: ${{ github.ref == 'refs/heads/master' }}
|
||||||
uses: gradle/gradle-build-action@v2
|
uses: gradle/gradle-build-action@v2.4.2
|
||||||
with:
|
with:
|
||||||
arguments: publish
|
arguments: publish
|
||||||
build-root-directory: bindings/java
|
build-root-directory: bindings/java
|
||||||
env:
|
env:
|
||||||
MAVEN_USERNAME: ${{ secrets.JIRA_USER }}
|
MAVEN_USERNAME: ${{ secrets.JIRA_USER }}
|
||||||
MAVEN_PASSWORD: ${{ secrets.JIRA_PASS }}
|
MAVEN_PASSWORD: ${{ secrets.JIRA_PASS }}
|
||||||
# MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }}
|
PGP_SECRET: ${{ secrets.GPG_PRIVATE_KEY }}
|
||||||
# MAVEN_PASSWORD: ${{ secrets.OSSRH_TOKEN }}
|
PGP_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }}
|
||||||
|
|
||||||
quantize:
|
quantize:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
2
.gitignore
vendored
2
.gitignore
vendored
@ -46,3 +46,5 @@ models/*.mlpackage
|
|||||||
bindings/java/.gradle/
|
bindings/java/.gradle/
|
||||||
bindings/java/.idea/
|
bindings/java/.idea/
|
||||||
.idea/
|
.idea/
|
||||||
|
|
||||||
|
benchmark_results.csv
|
||||||
|
123
CMakeLists.txt
123
CMakeLists.txt
@ -1,4 +1,4 @@
|
|||||||
cmake_minimum_required (VERSION 3.0)
|
cmake_minimum_required (VERSION 3.5)
|
||||||
|
|
||||||
project(whisper.cpp VERSION 1.4.2)
|
project(whisper.cpp VERSION 1.4.2)
|
||||||
|
|
||||||
@ -35,6 +35,12 @@ endif()
|
|||||||
|
|
||||||
# options
|
# options
|
||||||
|
|
||||||
|
if (APPLE)
|
||||||
|
set(WHISPER_METAL_DEFAULT ON)
|
||||||
|
else()
|
||||||
|
set(WHISPER_METAL_DEFAULT OFF)
|
||||||
|
endif()
|
||||||
|
|
||||||
option(BUILD_SHARED_LIBS "whisper: build shared libs" ${BUILD_SHARED_LIBS_DEFAULT})
|
option(BUILD_SHARED_LIBS "whisper: build shared libs" ${BUILD_SHARED_LIBS_DEFAULT})
|
||||||
|
|
||||||
option(WHISPER_ALL_WARNINGS "whisper: enable all compiler warnings" ON)
|
option(WHISPER_ALL_WARNINGS "whisper: enable all compiler warnings" ON)
|
||||||
@ -58,6 +64,8 @@ option(WHISPER_OPENVINO "whisper: support for OpenVINO" OFF)
|
|||||||
|
|
||||||
if (APPLE)
|
if (APPLE)
|
||||||
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
|
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
|
||||||
|
option(WHISPER_METAL "whisper: use Metal" ${WHISPER_METAL_DEFAULT})
|
||||||
|
option(WHISPER_METAL_NDEBUG "whisper: disable Metal debugging" OFF)
|
||||||
option(WHISPER_COREML "whisper: enable Core ML framework" OFF)
|
option(WHISPER_COREML "whisper: enable Core ML framework" OFF)
|
||||||
option(WHISPER_COREML_ALLOW_FALLBACK "whisper: allow non-CoreML fallback" OFF)
|
option(WHISPER_COREML_ALLOW_FALLBACK "whisper: allow non-CoreML fallback" OFF)
|
||||||
else()
|
else()
|
||||||
@ -109,10 +117,38 @@ if (APPLE)
|
|||||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
|
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
|
||||||
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_ACCELERATE)
|
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_ACCELERATE)
|
||||||
else()
|
else()
|
||||||
message(WARNING "Accelerate framework not found")
|
message(FATAL_ERROR "Accelerate framework not found")
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if (WHISPER_METAL)
|
||||||
|
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
|
||||||
|
find_library(METAL_FRAMEWORK Metal REQUIRED)
|
||||||
|
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
|
||||||
|
|
||||||
|
if (METAL_FRAMEWORK)
|
||||||
|
message(STATUS "Metal framework found")
|
||||||
|
|
||||||
|
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS}
|
||||||
|
${FOUNDATION_LIBRARY}
|
||||||
|
${METAL_FRAMEWORK}
|
||||||
|
${METALKIT_FRAMEWORK}
|
||||||
|
)
|
||||||
|
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_METAL)
|
||||||
|
|
||||||
|
if (WHISPER_METAL_NDEBUG)
|
||||||
|
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_METAL_NDEBUG)
|
||||||
|
endif()
|
||||||
|
else()
|
||||||
|
message(FATAL_ERROR "Metal framework not found")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h)
|
||||||
|
|
||||||
|
# copy ggml-metal.metal to bin directory
|
||||||
|
configure_file(ggml-metal.metal bin/ggml-metal.metal COPYONLY)
|
||||||
|
endif()
|
||||||
|
|
||||||
if (WHISPER_COREML)
|
if (WHISPER_COREML)
|
||||||
find_library(FOUNDATION_FRAMEWORK Foundation)
|
find_library(FOUNDATION_FRAMEWORK Foundation)
|
||||||
find_library(COREML_FRAMEWORK CoreML)
|
find_library(COREML_FRAMEWORK CoreML)
|
||||||
@ -122,7 +158,7 @@ if (APPLE)
|
|||||||
|
|
||||||
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_USE_COREML)
|
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_USE_COREML)
|
||||||
else()
|
else()
|
||||||
message(WARNING "CoreML framework not found")
|
message(FATAL_ERROR "CoreML framework not found")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (WHISPER_COREML_ALLOW_FALLBACK)
|
if (WHISPER_COREML_ALLOW_FALLBACK)
|
||||||
@ -145,13 +181,13 @@ if (WHISPER_BLAS)
|
|||||||
include_directories($ENV{OPENBLAS_PATH}/include)
|
include_directories($ENV{OPENBLAS_PATH}/include)
|
||||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${BLAS_LIBRARIES})
|
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${BLAS_LIBRARIES})
|
||||||
else ()
|
else ()
|
||||||
message(WARNING "BLAS library was not found. Environment variable OPENBLAS_PATH not defined.")
|
message(FATAL_ERROR "BLAS library was not found. Environment variable OPENBLAS_PATH not defined.")
|
||||||
endif ()
|
endif ()
|
||||||
else ()
|
else ()
|
||||||
set(BLA_STATIC 1)
|
set(BLA_STATIC 1)
|
||||||
set(BLA_VENDOR ${WHISPER_BLAS_VENDOR})
|
set(BLA_VENDOR ${WHISPER_BLAS_VENDOR})
|
||||||
# set(BLA_PREFER_PKGCONFIG 1)
|
|
||||||
set(BLA_SIZEOF_INTEGER 8)
|
set(BLA_SIZEOF_INTEGER 8)
|
||||||
|
set(BLA_PREFER_PKGCONFIG 1)
|
||||||
find_package(BLAS)
|
find_package(BLAS)
|
||||||
|
|
||||||
if(BLAS_FOUND)
|
if(BLAS_FOUND)
|
||||||
@ -162,7 +198,7 @@ if (WHISPER_BLAS)
|
|||||||
include_directories(${BLAS_INCLUDE_DIRS})
|
include_directories(${BLAS_INCLUDE_DIRS})
|
||||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${BLAS_LIBRARIES})
|
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${BLAS_LIBRARIES})
|
||||||
else()
|
else()
|
||||||
message(WARNING "BLAS library was not found")
|
message(FATAL_ERROR "BLAS library was not found")
|
||||||
endif()
|
endif()
|
||||||
endif ()
|
endif ()
|
||||||
endif ()
|
endif ()
|
||||||
@ -177,7 +213,7 @@ if (WHISPER_CUBLAS)
|
|||||||
|
|
||||||
enable_language(CUDA)
|
enable_language(CUDA)
|
||||||
|
|
||||||
set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
|
set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h)
|
||||||
|
|
||||||
add_compile_definitions(GGML_USE_CUBLAS)
|
add_compile_definitions(GGML_USE_CUBLAS)
|
||||||
|
|
||||||
@ -188,7 +224,7 @@ if (WHISPER_CUBLAS)
|
|||||||
endif()
|
endif()
|
||||||
|
|
||||||
else()
|
else()
|
||||||
message(WARNING "cuBLAS not found")
|
message(FATAL_ERROR "cuBLAS not found")
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
@ -219,7 +255,7 @@ if (WHISPER_HIPBLAS)
|
|||||||
endif()
|
endif()
|
||||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ggml-rocm)
|
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ggml-rocm)
|
||||||
else()
|
else()
|
||||||
message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
|
message(FATAL_ERROR "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
@ -228,13 +264,13 @@ if (WHISPER_CLBLAST)
|
|||||||
if (CLBlast_FOUND)
|
if (CLBlast_FOUND)
|
||||||
message(STATUS "CLBlast found")
|
message(STATUS "CLBlast found")
|
||||||
|
|
||||||
set(GGML_OPENCL_SOURCES ggml-opencl.cpp ggml-opencl.h)
|
set(GGML_SOURCES_OPENCL ggml-opencl.cpp ggml-opencl.h)
|
||||||
|
|
||||||
add_compile_definitions(GGML_USE_CLBLAST)
|
add_compile_definitions(GGML_USE_CLBLAST)
|
||||||
|
|
||||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} clblast)
|
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} clblast)
|
||||||
else()
|
else()
|
||||||
message(WARNING "CLBlast not found")
|
message(FATAL_ERROR "CLBlast not found")
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
@ -321,6 +357,53 @@ else()
|
|||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
#
|
||||||
|
# POSIX conformance
|
||||||
|
#
|
||||||
|
|
||||||
|
# clock_gettime came in POSIX.1b (1993)
|
||||||
|
# CLOCK_MONOTONIC came in POSIX.1-2001 / SUSv3 as optional
|
||||||
|
# posix_memalign came in POSIX.1-2001 / SUSv3
|
||||||
|
# M_PI is an XSI extension since POSIX.1-2001 / SUSv3, came in XPG1 (1985)
|
||||||
|
add_compile_definitions(_XOPEN_SOURCE=600)
|
||||||
|
|
||||||
|
# Somehow in OpenBSD whenever POSIX conformance is specified
|
||||||
|
# some string functions rely on locale_t availability,
|
||||||
|
# which was introduced in POSIX.1-2008, forcing us to go higher
|
||||||
|
if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
|
||||||
|
remove_definitions(-D_XOPEN_SOURCE=600)
|
||||||
|
add_compile_definitions(_XOPEN_SOURCE=700)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Data types, macros and functions related to controlling CPU affinity
|
||||||
|
# are available on Linux through GNU extensions in libc
|
||||||
|
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||||
|
add_compile_definitions(_GNU_SOURCE)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1,
|
||||||
|
# and on macOS its availability depends on enabling Darwin extensions
|
||||||
|
# similarly on DragonFly, enabling BSD extensions is necessary
|
||||||
|
if (CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
||||||
|
add_compile_definitions(_DARWIN_C_SOURCE)
|
||||||
|
endif()
|
||||||
|
if (CMAKE_SYSTEM_NAME MATCHES "DragonFly")
|
||||||
|
add_compile_definitions(_DARWIN_C_SOURCE)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# alloca is a non-standard interface that is not visible on BSDs when
|
||||||
|
# POSIX conformance is specified, but not all of them provide a clean way
|
||||||
|
# to enable it in such cases
|
||||||
|
if (CMAKE_SYSTEM_NAME MATCHES "FreeBSD")
|
||||||
|
add_compile_definitions(__BSD_VISIBLE)
|
||||||
|
endif()
|
||||||
|
if (CMAKE_SYSTEM_NAME MATCHES "NetBSD")
|
||||||
|
add_compile_definitions(_NETBSD_SOURCE)
|
||||||
|
endif()
|
||||||
|
if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
|
||||||
|
add_compile_definitions(_BSD_SOURCE)
|
||||||
|
endif()
|
||||||
|
|
||||||
if (WHISPER_PERF)
|
if (WHISPER_PERF)
|
||||||
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_PERF)
|
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_PERF)
|
||||||
endif()
|
endif()
|
||||||
@ -379,8 +462,11 @@ set(TARGET whisper)
|
|||||||
add_library(${TARGET}
|
add_library(${TARGET}
|
||||||
ggml.h
|
ggml.h
|
||||||
ggml.c
|
ggml.c
|
||||||
${GGML_CUDA_SOURCES}
|
ggml-alloc.h
|
||||||
${GGML_OPENCL_SOURCES}
|
ggml-alloc.c
|
||||||
|
${GGML_SOURCES_METAL}
|
||||||
|
${GGML_SOURCES_CUDA}
|
||||||
|
${GGML_SOURCES_OPENCL}
|
||||||
whisper.h
|
whisper.h
|
||||||
whisper.cpp
|
whisper.cpp
|
||||||
)
|
)
|
||||||
@ -421,9 +507,15 @@ if (BUILD_SHARED_LIBS)
|
|||||||
WHISPER_BUILD
|
WHISPER_BUILD
|
||||||
GGML_BUILD
|
GGML_BUILD
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (WHISPER_METAL)
|
||||||
|
# TODO: I think this should make ggml-metal.m "see" the ggml-metal.metal file from the "bin" directory
|
||||||
|
# but for some reason it does not work here like it does in llama.cpp
|
||||||
|
set_target_properties(${TARGET} PROPERTIES RESOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
|
||||||
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (GGML_CUDA_SOURCES)
|
if (GGML_SOURCES_CUDA)
|
||||||
message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
|
message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
|
||||||
set_property(TARGET whisper PROPERTY CUDA_ARCHITECTURES OFF)
|
set_property(TARGET whisper PROPERTY CUDA_ARCHITECTURES OFF)
|
||||||
set_property(TARGET whisper PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
|
set_property(TARGET whisper PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
|
||||||
@ -439,10 +531,13 @@ target_compile_definitions(${TARGET} PUBLIC
|
|||||||
|
|
||||||
set_target_properties(${TARGET} PROPERTIES PUBLIC_HEADER "whisper.h")
|
set_target_properties(${TARGET} PROPERTIES PUBLIC_HEADER "whisper.h")
|
||||||
|
|
||||||
|
include(GNUInstallDirs)
|
||||||
|
|
||||||
install(TARGETS ${TARGET}
|
install(TARGETS ${TARGET}
|
||||||
LIBRARY DESTINATION lib
|
LIBRARY DESTINATION lib
|
||||||
ARCHIVE DESTINATION lib/static
|
ARCHIVE DESTINATION lib/static
|
||||||
RUNTIME DESTINATION bin
|
RUNTIME DESTINATION bin
|
||||||
|
RESOURCE DESTINATION bin
|
||||||
PUBLIC_HEADER DESTINATION include
|
PUBLIC_HEADER DESTINATION include
|
||||||
)
|
)
|
||||||
|
|
||||||
|
73
Makefile
73
Makefile
@ -42,18 +42,55 @@ CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC
|
|||||||
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
|
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
|
||||||
LDFLAGS =
|
LDFLAGS =
|
||||||
|
|
||||||
# ref: https://github.com/ggerganov/whisper.cpp/issues/37
|
# clock_gettime came in POSIX.1b (1993)
|
||||||
ifneq ($(wildcard /usr/include/musl/*),)
|
# CLOCK_MONOTONIC came in POSIX.1-2001 / SUSv3 as optional
|
||||||
CFLAGS += -D_POSIX_SOURCE -D_GNU_SOURCE
|
# posix_memalign came in POSIX.1-2001 / SUSv3
|
||||||
CXXFLAGS += -D_POSIX_SOURCE -D_GNU_SOURCE
|
# M_PI is an XSI extension since POSIX.1-2001 / SUSv3, came in XPG1 (1985)
|
||||||
|
CFLAGS += -D_XOPEN_SOURCE=600
|
||||||
|
CXXFLAGS += -D_XOPEN_SOURCE=600
|
||||||
|
|
||||||
|
# Somehow in OpenBSD whenever POSIX conformance is specified
|
||||||
|
# some string functions rely on locale_t availability,
|
||||||
|
# which was introduced in POSIX.1-2008, forcing us to go higher
|
||||||
|
ifeq ($(UNAME_S),OpenBSD)
|
||||||
|
CFLAGS += -U_XOPEN_SOURCE -D_XOPEN_SOURCE=700
|
||||||
|
CXXFLAGS += -U_XOPEN_SOURCE -D_XOPEN_SOURCE=700
|
||||||
|
endif
|
||||||
|
|
||||||
|
# Data types, macros and functions related to controlling CPU affinity
|
||||||
|
# are available on Linux through GNU extensions in libc
|
||||||
|
ifeq ($(UNAME_S),Linux)
|
||||||
|
CFLAGS += -D_GNU_SOURCE
|
||||||
|
CXXFLAGS += -D_GNU_SOURCE
|
||||||
endif
|
endif
|
||||||
|
|
||||||
# RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1,
|
# RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1,
|
||||||
# and on macOS its availability depends on enabling Darwin extensions
|
# and on macOS its availability depends on enabling Darwin extensions
|
||||||
|
# similarly on DragonFly, enabling BSD extensions is necessary
|
||||||
ifeq ($(UNAME_S),Darwin)
|
ifeq ($(UNAME_S),Darwin)
|
||||||
CFLAGS += -D_DARWIN_C_SOURCE
|
CFLAGS += -D_DARWIN_C_SOURCE
|
||||||
CXXFLAGS += -D_DARWIN_C_SOURCE
|
CXXFLAGS += -D_DARWIN_C_SOURCE
|
||||||
endif
|
endif
|
||||||
|
ifeq ($(UNAME_S),DragonFly)
|
||||||
|
CFLAGS += -D__BSD_VISIBLE
|
||||||
|
CXXFLAGS += -D__BSD_VISIBLE
|
||||||
|
endif
|
||||||
|
|
||||||
|
# alloca is a non-standard interface that is not visible on BSDs when
|
||||||
|
# POSIX conformance is specified, but not all of them provide a clean way
|
||||||
|
# to enable it in such cases
|
||||||
|
ifeq ($(UNAME_S),FreeBSD)
|
||||||
|
CFLAGS += -D__BSD_VISIBLE
|
||||||
|
CXXFLAGS += -D__BSD_VISIBLE
|
||||||
|
endif
|
||||||
|
ifeq ($(UNAME_S),NetBSD)
|
||||||
|
CFLAGS += -D_NETBSD_SOURCE
|
||||||
|
CXXFLAGS += -D_NETBSD_SOURCE
|
||||||
|
endif
|
||||||
|
ifeq ($(UNAME_S),OpenBSD)
|
||||||
|
CFLAGS += -D_BSD_SOURCE
|
||||||
|
CXXFLAGS += -D_BSD_SOURCE
|
||||||
|
endif
|
||||||
|
|
||||||
# OS specific
|
# OS specific
|
||||||
# TODO: support Windows
|
# TODO: support Windows
|
||||||
@ -67,7 +104,7 @@ endif
|
|||||||
# feel free to update the Makefile for your architecture and send a pull request or issue
|
# feel free to update the Makefile for your architecture and send a pull request or issue
|
||||||
ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686 amd64))
|
ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686 amd64))
|
||||||
ifeq ($(UNAME_S),Darwin)
|
ifeq ($(UNAME_S),Darwin)
|
||||||
CPUINFO_CMD := sysctl machdep.cpu.features
|
CPUINFO_CMD := sysctl machdep.cpu.features machdep.cpu.leaf7_features
|
||||||
else ifeq ($(UNAME_S),Linux)
|
else ifeq ($(UNAME_S),Linux)
|
||||||
CPUINFO_CMD := cat /proc/cpuinfo
|
CPUINFO_CMD := cat /proc/cpuinfo
|
||||||
else ifneq (,$(filter MINGW32_NT% MINGW64_NT%,$(UNAME_S)))
|
else ifneq (,$(filter MINGW32_NT% MINGW64_NT%,$(UNAME_S)))
|
||||||
@ -145,6 +182,16 @@ ifdef WHISPER_COREML_ALLOW_FALLBACK
|
|||||||
endif
|
endif
|
||||||
endif
|
endif
|
||||||
|
|
||||||
|
ifndef WHISPER_NO_METAL
|
||||||
|
ifeq ($(UNAME_S),Darwin)
|
||||||
|
WHISPER_METAL := 1
|
||||||
|
|
||||||
|
CFLAGS += -DGGML_USE_METAL
|
||||||
|
CXXFLAGS += -DGGML_USE_METAL
|
||||||
|
LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
|
||||||
|
endif
|
||||||
|
endif
|
||||||
|
|
||||||
ifdef WHISPER_OPENBLAS
|
ifdef WHISPER_OPENBLAS
|
||||||
CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas -I/usr/include/openblas
|
CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas -I/usr/include/openblas
|
||||||
LDFLAGS += -lopenblas
|
LDFLAGS += -lopenblas
|
||||||
@ -251,6 +298,11 @@ $(info )
|
|||||||
ggml.o: ggml.c ggml.h ggml-cuda.h
|
ggml.o: ggml.c ggml.h ggml-cuda.h
|
||||||
$(CC) $(CFLAGS) -c $< -o $@
|
$(CC) $(CFLAGS) -c $< -o $@
|
||||||
|
|
||||||
|
ggml-alloc.o: ggml-alloc.c ggml.h ggml-alloc.h
|
||||||
|
$(CC) $(CFLAGS) -c $< -o $@
|
||||||
|
|
||||||
|
WHISPER_OBJ += ggml-alloc.o
|
||||||
|
|
||||||
whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
|
whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
|
||||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||||
|
|
||||||
@ -266,6 +318,13 @@ whisper-encoder-impl.o: coreml/whisper-encoder-impl.m coreml/whisper-encoder-imp
|
|||||||
WHISPER_OBJ += whisper.o whisper-encoder.o whisper-encoder-impl.o
|
WHISPER_OBJ += whisper.o whisper-encoder.o whisper-encoder-impl.o
|
||||||
endif
|
endif
|
||||||
|
|
||||||
|
ifdef WHISPER_METAL
|
||||||
|
ggml-metal.o: ggml-metal.m ggml-metal.h
|
||||||
|
$(CC) $(CFLAGS) -c $< -o $@
|
||||||
|
|
||||||
|
WHISPER_OBJ += ggml-metal.o
|
||||||
|
endif
|
||||||
|
|
||||||
libwhisper.a: ggml.o $(WHISPER_OBJ)
|
libwhisper.a: ggml.o $(WHISPER_OBJ)
|
||||||
$(AR) rcs libwhisper.a ggml.o $(WHISPER_OBJ)
|
$(AR) rcs libwhisper.a ggml.o $(WHISPER_OBJ)
|
||||||
|
|
||||||
@ -297,8 +356,8 @@ quantize: examples/quantize/quantize.cpp ggml.o $(WHISPER_OBJ) $(SRC_COMMON)
|
|||||||
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
||||||
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
|
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
|
||||||
|
|
||||||
command: examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
||||||
$(CXX) $(CXXFLAGS) examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
|
$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
|
||||||
|
|
||||||
lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
||||||
$(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
|
$(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
|
||||||
|
38
README.md
38
README.md
@ -11,14 +11,14 @@ Beta: [v1.4.2](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.4.2) / S
|
|||||||
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
|
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
|
||||||
|
|
||||||
- Plain C/C++ implementation without dependencies
|
- Plain C/C++ implementation without dependencies
|
||||||
- Apple silicon first-class citizen - optimized via ARM NEON, Accelerate framework and [Core ML](https://github.com/ggerganov/whisper.cpp#core-ml-support)
|
- Apple Silicon first-class citizen - optimized via ARM NEON, Accelerate framework, Metal and [Core ML](https://github.com/ggerganov/whisper.cpp#core-ml-support)
|
||||||
- AVX intrinsics support for x86 architectures
|
- AVX intrinsics support for x86 architectures
|
||||||
- VSX intrinsics support for POWER architectures
|
- VSX intrinsics support for POWER architectures
|
||||||
- Mixed F16 / F32 precision
|
- Mixed F16 / F32 precision
|
||||||
- [4-bit and 5-bit integer quantization support](https://github.com/ggerganov/whisper.cpp#quantization)
|
- [4-bit and 5-bit integer quantization support](https://github.com/ggerganov/whisper.cpp#quantization)
|
||||||
- Low memory usage (Flash Attention)
|
- Low memory usage (Flash Attention)
|
||||||
- Zero memory allocations at runtime
|
- Zero memory allocations at runtime
|
||||||
- Runs on the CPU
|
- Support for CPU-only inference
|
||||||
- [Partial GPU support for NVIDIA via cuBLAS](https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas)
|
- [Partial GPU support for NVIDIA via cuBLAS](https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas)
|
||||||
- [Partial OpenCL GPU support via CLBlast](https://github.com/ggerganov/whisper.cpp#opencl-gpu-support-via-clblast)
|
- [Partial OpenCL GPU support via CLBlast](https://github.com/ggerganov/whisper.cpp#opencl-gpu-support-via-clblast)
|
||||||
- [BLAS CPU support via OpenBLAS](https://github.com/ggerganov/whisper.cpp#blas-cpu-support-via-openblas)
|
- [BLAS CPU support via OpenBLAS](https://github.com/ggerganov/whisper.cpp#blas-cpu-support-via-openblas)
|
||||||
@ -50,6 +50,10 @@ You can also easily make your own offline voice assistant application: [command]
|
|||||||
|
|
||||||
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
|
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
|
||||||
|
|
||||||
|
On Apple Silicon, the inference runs fully on the GPU via Metal:
|
||||||
|
|
||||||
|
https://github.com/ggerganov/whisper.cpp/assets/1991296/c82e8f86-60dc-49f2-b048-d2fdbd6b5225
|
||||||
|
|
||||||
Or you can even run it straight in the browser: [talk.wasm](examples/talk.wasm)
|
Or you can even run it straight in the browser: [talk.wasm](examples/talk.wasm)
|
||||||
|
|
||||||
## Implementation details
|
## Implementation details
|
||||||
@ -109,30 +113,37 @@ options:
|
|||||||
-d N, --duration N [0 ] duration of audio to process in milliseconds
|
-d N, --duration N [0 ] duration of audio to process in milliseconds
|
||||||
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
|
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
|
||||||
-ml N, --max-len N [0 ] maximum segment length in characters
|
-ml N, --max-len N [0 ] maximum segment length in characters
|
||||||
-bo N, --best-of N [5 ] number of best candidates to keep
|
-sow, --split-on-word [false ] split on word rather than on token
|
||||||
|
-bo N, --best-of N [2 ] number of best candidates to keep
|
||||||
-bs N, --beam-size N [-1 ] beam size for beam search
|
-bs N, --beam-size N [-1 ] beam size for beam search
|
||||||
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
|
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
|
||||||
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
|
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
|
||||||
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
|
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
|
||||||
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
|
-debug, --debug-mode [false ] enable debug mode (eg. dump log_mel)
|
||||||
-tr, --translate [false ] translate from source language to english
|
-tr, --translate [false ] translate from source language to english
|
||||||
-tdrz, --tinydiarize [false ] enable tinydiarize (requires a tdrz model)
|
|
||||||
-di, --diarize [false ] stereo audio diarization
|
-di, --diarize [false ] stereo audio diarization
|
||||||
|
-tdrz, --tinydiarize [false ] enable tinydiarize (requires a tdrz model)
|
||||||
-nf, --no-fallback [false ] do not use temperature fallback while decoding
|
-nf, --no-fallback [false ] do not use temperature fallback while decoding
|
||||||
-otxt, --output-txt [false ] output result in a text file
|
-otxt, --output-txt [false ] output result in a text file
|
||||||
-ovtt, --output-vtt [false ] output result in a vtt file
|
-ovtt, --output-vtt [false ] output result in a vtt file
|
||||||
-osrt, --output-srt [false ] output result in a srt file
|
-osrt, --output-srt [false ] output result in a srt file
|
||||||
|
-olrc, --output-lrc [false ] output result in a lrc file
|
||||||
-owts, --output-words [false ] output script for generating karaoke video
|
-owts, --output-words [false ] output script for generating karaoke video
|
||||||
|
-fp, --font-path [/System/Library/Fonts/Supplemental/Courier New Bold.ttf] path to a monospace font for karaoke video
|
||||||
-ocsv, --output-csv [false ] output result in a CSV file
|
-ocsv, --output-csv [false ] output result in a CSV file
|
||||||
|
-oj, --output-json [false ] output result in a JSON file
|
||||||
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
|
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
|
||||||
-ps, --print-special [false ] print special tokens
|
-ps, --print-special [false ] print special tokens
|
||||||
-pc, --print-colors [false ] print colors
|
-pc, --print-colors [false ] print colors
|
||||||
-pp, --print-progress [false ] print progress
|
-pp, --print-progress [false ] print progress
|
||||||
-nt, --no-timestamps [true ] do not print timestamps
|
-nt, --no-timestamps [false ] do not print timestamps
|
||||||
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
|
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
|
||||||
|
-dl, --detect-language [false ] exit after automatically detecting language
|
||||||
--prompt PROMPT [ ] initial prompt
|
--prompt PROMPT [ ] initial prompt
|
||||||
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
|
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
|
||||||
-f FNAME, --file FNAME [ ] input WAV file path
|
-f FNAME, --file FNAME [ ] input WAV file path
|
||||||
|
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
|
||||||
|
-ls, --log-score [false ] log best decoder scores of token
|
||||||
|
|
||||||
|
|
||||||
bash ./models/download-ggml-model.sh base.en
|
bash ./models/download-ggml-model.sh base.en
|
||||||
@ -698,6 +709,19 @@ took to execute it. The results are summarized in the following Github issue:
|
|||||||
|
|
||||||
[Benchmark results](https://github.com/ggerganov/whisper.cpp/issues/89)
|
[Benchmark results](https://github.com/ggerganov/whisper.cpp/issues/89)
|
||||||
|
|
||||||
|
Additionally a script to run whisper.cpp with different models and audio files is provided [bench.py](bench.py).
|
||||||
|
|
||||||
|
You can run it with the following command, by default it will run against any standard model in the models folder.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 extra/bench.py -f samples/jfk.wav -t 2,4,8 -p 1,2
|
||||||
|
```
|
||||||
|
|
||||||
|
It is written in python with the intention of being easy to modify and extend for your benchmarking use case.
|
||||||
|
|
||||||
|
It outputs a csv file with the results of the benchmarking.
|
||||||
|
|
||||||
|
|
||||||
## ggml format
|
## ggml format
|
||||||
|
|
||||||
The original models are converted to a custom binary format. This allows to pack everything needed into a single file:
|
The original models are converted to a custom binary format. This allows to pack everything needed into a single file:
|
||||||
@ -719,7 +743,7 @@ in [models](models).
|
|||||||
## [Bindings](https://github.com/ggerganov/whisper.cpp/discussions/categories/bindings)
|
## [Bindings](https://github.com/ggerganov/whisper.cpp/discussions/categories/bindings)
|
||||||
|
|
||||||
- [X] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggerganov/whisper.cpp/discussions/310)
|
- [X] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggerganov/whisper.cpp/discussions/310)
|
||||||
- [X] Javascript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309)
|
- [X] JavaScript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309)
|
||||||
- React Native (iOS / Android): [whisper.rn](https://github.com/mybigday/whisper.rn)
|
- React Native (iOS / Android): [whisper.rn](https://github.com/mybigday/whisper.rn)
|
||||||
- [X] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312)
|
- [X] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312)
|
||||||
- [X] Java:
|
- [X] Java:
|
||||||
|
@ -118,6 +118,11 @@ func (p *Params) SetMaxTokensPerSegment(n int) {
|
|||||||
p.max_tokens = C.int(n)
|
p.max_tokens = C.int(n)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set audio encoder context
|
||||||
|
func (p *Params) SetAudioCtx(n int) {
|
||||||
|
p.audio_ctx = C.int(n)
|
||||||
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// PRIVATE METHODS
|
// PRIVATE METHODS
|
||||||
|
|
||||||
@ -141,6 +146,7 @@ func (p *Params) String() string {
|
|||||||
str += fmt.Sprintf(" n_max_text_ctx=%d", p.n_max_text_ctx)
|
str += fmt.Sprintf(" n_max_text_ctx=%d", p.n_max_text_ctx)
|
||||||
str += fmt.Sprintf(" offset_ms=%d", p.offset_ms)
|
str += fmt.Sprintf(" offset_ms=%d", p.offset_ms)
|
||||||
str += fmt.Sprintf(" duration_ms=%d", p.duration_ms)
|
str += fmt.Sprintf(" duration_ms=%d", p.duration_ms)
|
||||||
|
str += fmt.Sprintf(" audio_ctx=%d", p.audio_ctx)
|
||||||
if p.translate {
|
if p.translate {
|
||||||
str += " translate"
|
str += " translate"
|
||||||
}
|
}
|
||||||
|
@ -125,6 +125,11 @@ func (context *context) SetMaxTokensPerSegment(n uint) {
|
|||||||
context.params.SetMaxTokensPerSegment(int(n))
|
context.params.SetMaxTokensPerSegment(int(n))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set audio encoder context
|
||||||
|
func (context *context) SetAudioCtx(n uint) {
|
||||||
|
context.params.SetAudioCtx(int(n))
|
||||||
|
}
|
||||||
|
|
||||||
// ResetTimings resets the mode timings. Should be called before processing
|
// ResetTimings resets the mode timings. Should be called before processing
|
||||||
func (context *context) ResetTimings() {
|
func (context *context) ResetTimings() {
|
||||||
context.model.ctx.Whisper_reset_timings()
|
context.model.ctx.Whisper_reset_timings()
|
||||||
|
@ -48,6 +48,7 @@ type Context interface {
|
|||||||
SetMaxSegmentLength(uint) // Set max segment length in characters
|
SetMaxSegmentLength(uint) // Set max segment length in characters
|
||||||
SetTokenTimestamps(bool) // Set token timestamps flag
|
SetTokenTimestamps(bool) // Set token timestamps flag
|
||||||
SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit)
|
SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit)
|
||||||
|
SetAudioCtx(uint) // Set audio encoder context
|
||||||
|
|
||||||
// Process mono audio data and return any errors.
|
// Process mono audio data and return any errors.
|
||||||
// If defined, newly generated segments are passed to the
|
// If defined, newly generated segments are passed to the
|
||||||
|
Submodule bindings/ios updated: de46d9e781...22a9eef021
@ -2,6 +2,7 @@ plugins {
|
|||||||
id 'java'
|
id 'java'
|
||||||
id 'java-library'
|
id 'java-library'
|
||||||
id 'maven-publish'
|
id 'maven-publish'
|
||||||
|
id 'signing'
|
||||||
}
|
}
|
||||||
|
|
||||||
archivesBaseName = 'whispercpp'
|
archivesBaseName = 'whispercpp'
|
||||||
@ -109,4 +110,23 @@ publishing {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
repositories {
|
||||||
|
maven {
|
||||||
|
def releasesRepoUrl = 'https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/'
|
||||||
|
def snapshotsRepoUrl = 'https://s01.oss.sonatype.org/content/repositories/snapshots/'
|
||||||
|
url = version.endsWith('-SNAPSHOT') ? snapshotsRepoUrl : releasesRepoUrl
|
||||||
|
credentials {
|
||||||
|
username = System.getenv("MAVEN_USERNAME")
|
||||||
|
password = System.getenv("MAVEN_PASSWORD")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
signing {
|
||||||
|
def signingKey = System.getenv("PGP_SECRET")
|
||||||
|
def signingPassword = System.getenv("PGP_PASSPHRASE")
|
||||||
|
useInMemoryPgpKeys(signingKey, signingPassword)
|
||||||
|
sign publishing.publications.mavenJava
|
||||||
}
|
}
|
||||||
|
2
bindings/ruby/ext/.gitignore
vendored
2
bindings/ruby/ext/.gitignore
vendored
@ -1,6 +1,8 @@
|
|||||||
Makefile
|
Makefile
|
||||||
ggml.c
|
ggml.c
|
||||||
ggml.h
|
ggml.h
|
||||||
|
ggml-alloc.c
|
||||||
|
ggml-alloc.h
|
||||||
whisper.bundle
|
whisper.bundle
|
||||||
whisper.cpp
|
whisper.cpp
|
||||||
whisper.h
|
whisper.h
|
||||||
|
@ -3,6 +3,8 @@ system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.cpp')} .")
|
|||||||
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .")
|
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .")
|
||||||
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .")
|
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .")
|
||||||
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .")
|
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .")
|
||||||
|
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-alloc.h')} .")
|
||||||
|
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-alloc.c')} .")
|
||||||
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','examples','dr_wav.h')} .")
|
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','examples','dr_wav.h')} .")
|
||||||
|
|
||||||
|
|
||||||
|
@ -22,7 +22,13 @@ struct whisper_coreml_context * whisper_coreml_init(const char * path_model) {
|
|||||||
|
|
||||||
NSURL * url_model = [NSURL fileURLWithPath: path_model_str];
|
NSURL * url_model = [NSURL fileURLWithPath: path_model_str];
|
||||||
|
|
||||||
const void * data = CFBridgingRetain([[whisper_encoder_impl alloc] initWithContentsOfURL:url_model error:nil]);
|
// select which device to run the Core ML model on
|
||||||
|
MLModelConfiguration *config = [[MLModelConfiguration alloc] init];
|
||||||
|
config.computeUnits = MLComputeUnitsCPUAndGPU;
|
||||||
|
//config.computeUnits = MLComputeUnitsCPUAndNeuralEngine;
|
||||||
|
//config.computeUnits = MLComputeUnitsAll;
|
||||||
|
|
||||||
|
const void * data = CFBridgingRetain([[whisper_encoder_impl alloc] initWithContentsOfURL:url_model configuration:config error:nil]);
|
||||||
|
|
||||||
if (data == NULL) {
|
if (data == NULL) {
|
||||||
return NULL;
|
return NULL;
|
||||||
|
@ -23,7 +23,6 @@ add_library(${TARGET} STATIC
|
|||||||
common.cpp
|
common.cpp
|
||||||
common-ggml.h
|
common-ggml.h
|
||||||
common-ggml.cpp
|
common-ggml.cpp
|
||||||
grammar-parser.cpp
|
|
||||||
)
|
)
|
||||||
|
|
||||||
include(DefaultTargetOptions)
|
include(DefaultTargetOptions)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
#include "whisper.h"
|
#include "whisper.h"
|
||||||
|
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
|
#include <cstring>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
|
|
||||||
@ -44,13 +45,13 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|||||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||||
fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what);
|
fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what);
|
||||||
fprintf(stderr, " %-7s 0 - whisper encoder\n", "");
|
fprintf(stderr, " %-7s 0 - whisper\n", "");
|
||||||
fprintf(stderr, " %-7s 1 - memcpy\n", "");
|
fprintf(stderr, " %-7s 1 - memcpy\n", "");
|
||||||
fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", "");
|
fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", "");
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
int whisper_bench_encoder(const whisper_params & params) {
|
int whisper_bench_full(const whisper_params & params) {
|
||||||
// whisper init
|
// whisper init
|
||||||
|
|
||||||
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
|
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
|
||||||
@ -69,12 +70,49 @@ int whisper_bench_encoder(const whisper_params & params) {
|
|||||||
fprintf(stderr, "error: failed to set mel: %d\n", ret);
|
fprintf(stderr, "error: failed to set mel: %d\n", ret);
|
||||||
return 3;
|
return 3;
|
||||||
}
|
}
|
||||||
|
// heat encoder
|
||||||
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
|
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
|
||||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||||
return 4;
|
return 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
whisper_token tokens[512];
|
||||||
|
memset(tokens, 0, sizeof(tokens));
|
||||||
|
|
||||||
|
// prompt heat
|
||||||
|
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
|
||||||
|
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||||
|
return 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
// text-generation heat
|
||||||
|
if (int ret = whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) {
|
||||||
|
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||||
|
return 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
whisper_reset_timings(ctx);
|
||||||
|
|
||||||
|
// actual run
|
||||||
|
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
|
||||||
|
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||||
|
return 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < 16; i++) {
|
||||||
|
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
|
||||||
|
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||||
|
return 4;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < 256; i++) {
|
||||||
|
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
|
||||||
|
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||||
|
return 4;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
whisper_print_timings(ctx);
|
whisper_print_timings(ctx);
|
||||||
whisper_free(ctx);
|
whisper_free(ctx);
|
||||||
|
|
||||||
@ -103,7 +141,7 @@ int main(int argc, char ** argv) {
|
|||||||
int ret = -1;
|
int ret = -1;
|
||||||
|
|
||||||
switch (params.what) {
|
switch (params.what) {
|
||||||
case 0: ret = whisper_bench_encoder(params); break;
|
case 0: ret = whisper_bench_full(params); break;
|
||||||
case 1: ret = whisper_bench_memcpy(params.n_threads); break;
|
case 1: ret = whisper_bench_memcpy(params.n_threads); break;
|
||||||
case 2: ret = whisper_bench_ggml_mul_mat(params.n_threads); break;
|
case 2: ret = whisper_bench_ggml_mul_mat(params.n_threads); break;
|
||||||
default: fprintf(stderr, "error: unknown benchmark: %d\n", params.what); break;
|
default: fprintf(stderr, "error: unknown benchmark: %d\n", params.what); break;
|
||||||
|
@ -6,10 +6,9 @@
|
|||||||
// ref: https://github.com/ggerganov/whisper.cpp/issues/171
|
// ref: https://github.com/ggerganov/whisper.cpp/issues/171
|
||||||
//
|
//
|
||||||
|
|
||||||
#include "common.h"
|
|
||||||
#include "common-sdl.h"
|
#include "common-sdl.h"
|
||||||
|
#include "common.h"
|
||||||
#include "whisper.h"
|
#include "whisper.h"
|
||||||
#include "grammar-parser.h"
|
|
||||||
|
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
@ -22,11 +21,6 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
|
||||||
bool file_exists(const std::string & fname) {
|
|
||||||
std::ifstream f(fname.c_str());
|
|
||||||
return f.good();
|
|
||||||
}
|
|
||||||
|
|
||||||
// command-line parameters
|
// command-line parameters
|
||||||
struct whisper_params {
|
struct whisper_params {
|
||||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||||
@ -39,10 +33,6 @@ struct whisper_params {
|
|||||||
float vad_thold = 0.6f;
|
float vad_thold = 0.6f;
|
||||||
float freq_thold = 100.0f;
|
float freq_thold = 100.0f;
|
||||||
|
|
||||||
float grammar_penalty = 100.0f;
|
|
||||||
|
|
||||||
grammar_parser::parse_state grammar_parsed;
|
|
||||||
|
|
||||||
bool speed_up = false;
|
bool speed_up = false;
|
||||||
bool translate = false;
|
bool translate = false;
|
||||||
bool print_special = false;
|
bool print_special = false;
|
||||||
@ -54,8 +44,6 @@ struct whisper_params {
|
|||||||
std::string fname_out;
|
std::string fname_out;
|
||||||
std::string commands;
|
std::string commands;
|
||||||
std::string prompt;
|
std::string prompt;
|
||||||
std::string context;
|
|
||||||
std::string grammar;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||||
@ -85,9 +73,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|||||||
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
||||||
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
|
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
|
||||||
else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
|
else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
|
||||||
else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
|
|
||||||
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
|
|
||||||
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
|
|
||||||
else {
|
else {
|
||||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||||
whisper_print_usage(argc, argv, params);
|
whisper_print_usage(argc, argv, params);
|
||||||
@ -121,30 +106,16 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|||||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
||||||
fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
|
fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
|
||||||
fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str());
|
fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str());
|
||||||
fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str());
|
|
||||||
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
|
|
||||||
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
|
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string transcribe(
|
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
|
||||||
whisper_context * ctx,
|
|
||||||
const whisper_params & params,
|
|
||||||
const std::vector<float> & pcmf32,
|
|
||||||
const std::string & grammar_rule,
|
|
||||||
float & logprob_min,
|
|
||||||
float & logprob_sum,
|
|
||||||
int & n_tokens,
|
|
||||||
int64_t & t_ms) {
|
|
||||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
logprob_min = 0.0f;
|
prob = 0.0f;
|
||||||
logprob_sum = 0.0f;
|
|
||||||
n_tokens = 0;
|
|
||||||
t_ms = 0;
|
t_ms = 0;
|
||||||
|
|
||||||
//whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
|
|
||||||
|
|
||||||
wparams.print_progress = false;
|
wparams.print_progress = false;
|
||||||
wparams.print_special = params.print_special;
|
wparams.print_special = params.print_special;
|
||||||
@ -152,7 +123,6 @@ std::string transcribe(
|
|||||||
wparams.print_timestamps = !params.no_timestamps;
|
wparams.print_timestamps = !params.no_timestamps;
|
||||||
wparams.translate = params.translate;
|
wparams.translate = params.translate;
|
||||||
wparams.no_context = true;
|
wparams.no_context = true;
|
||||||
wparams.no_timestamps = params.no_timestamps;
|
|
||||||
wparams.single_segment = true;
|
wparams.single_segment = true;
|
||||||
wparams.max_tokens = params.max_tokens;
|
wparams.max_tokens = params.max_tokens;
|
||||||
wparams.language = params.language.c_str();
|
wparams.language = params.language.c_str();
|
||||||
@ -161,28 +131,11 @@ std::string transcribe(
|
|||||||
wparams.audio_ctx = params.audio_ctx;
|
wparams.audio_ctx = params.audio_ctx;
|
||||||
wparams.speed_up = params.speed_up;
|
wparams.speed_up = params.speed_up;
|
||||||
|
|
||||||
wparams.temperature = 0.4f;
|
|
||||||
wparams.temperature_inc = 1.0f;
|
|
||||||
wparams.greedy.best_of = 5;
|
|
||||||
|
|
||||||
wparams.beam_search.beam_size = 5;
|
|
||||||
|
|
||||||
wparams.initial_prompt = params.context.data();
|
|
||||||
|
|
||||||
const auto & grammar_parsed = params.grammar_parsed;
|
|
||||||
auto grammar_rules = grammar_parsed.c_rules();
|
|
||||||
|
|
||||||
if (!params.grammar_parsed.rules.empty() && !grammar_rule.empty()) {
|
|
||||||
wparams.grammar_rules = grammar_rules.data();
|
|
||||||
wparams.n_grammar_rules = grammar_rules.size();
|
|
||||||
wparams.i_start_rule = grammar_parsed.symbol_ids.at(grammar_rule);
|
|
||||||
wparams.grammar_penalty = params.grammar_penalty;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int prob_n = 0;
|
||||||
std::string result;
|
std::string result;
|
||||||
|
|
||||||
const int n_segments = whisper_full_n_segments(ctx);
|
const int n_segments = whisper_full_n_segments(ctx);
|
||||||
@ -191,17 +144,19 @@ std::string transcribe(
|
|||||||
|
|
||||||
result += text;
|
result += text;
|
||||||
|
|
||||||
const int n = whisper_full_n_tokens(ctx, i);
|
const int n_tokens = whisper_full_n_tokens(ctx, i);
|
||||||
for (int j = 0; j < n; ++j) {
|
for (int j = 0; j < n_tokens; ++j) {
|
||||||
const auto token = whisper_full_get_token_data(ctx, i, j);
|
const auto token = whisper_full_get_token_data(ctx, i, j);
|
||||||
|
|
||||||
if(token.plog > 0.0f) exit(0);
|
prob += token.p;
|
||||||
logprob_min = std::min(logprob_min, token.plog);
|
++prob_n;
|
||||||
logprob_sum += token.plog;
|
|
||||||
++n_tokens;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (prob_n > 0) {
|
||||||
|
prob /= prob_n;
|
||||||
|
}
|
||||||
|
|
||||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||||
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
|
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
|
||||||
|
|
||||||
@ -460,9 +415,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
|||||||
bool is_running = true;
|
bool is_running = true;
|
||||||
bool ask_prompt = true;
|
bool ask_prompt = true;
|
||||||
|
|
||||||
float logprob_min = 0.0f;
|
float prob = 0.0f;
|
||||||
float logprob_sum = 0.0f;
|
|
||||||
int n_tokens = 0;
|
|
||||||
|
|
||||||
std::vector<float> pcmf32_cur;
|
std::vector<float> pcmf32_cur;
|
||||||
|
|
||||||
@ -500,7 +453,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
|||||||
// detect the commands
|
// detect the commands
|
||||||
audio.get(params.command_ms, pcmf32_cur);
|
audio.get(params.command_ms, pcmf32_cur);
|
||||||
|
|
||||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "", logprob_min, logprob_sum, n_tokens, t_ms));
|
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
||||||
|
|
||||||
const auto words = get_words(txt);
|
const auto words = get_words(txt);
|
||||||
|
|
||||||
@ -541,22 +494,13 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au
|
|||||||
bool have_prompt = false;
|
bool have_prompt = false;
|
||||||
bool ask_prompt = true;
|
bool ask_prompt = true;
|
||||||
|
|
||||||
float logprob_min0 = 0.0f;
|
float prob0 = 0.0f;
|
||||||
float logprob_min = 0.0f;
|
float prob = 0.0f;
|
||||||
|
|
||||||
float logprob_sum0 = 0.0f;
|
|
||||||
float logprob_sum = 0.0f;
|
|
||||||
|
|
||||||
int n_tokens0 = 0;
|
|
||||||
int n_tokens = 0;
|
|
||||||
|
|
||||||
std::vector<float> pcmf32_cur;
|
std::vector<float> pcmf32_cur;
|
||||||
std::vector<float> pcmf32_prompt;
|
std::vector<float> pcmf32_prompt;
|
||||||
|
|
||||||
std::string k_prompt = "Ok Whisper, start listening for commands.";
|
const std::string k_prompt = "Ok Whisper, start listening for commands.";
|
||||||
if (!params.prompt.empty()) {
|
|
||||||
k_prompt = params.prompt;
|
|
||||||
}
|
|
||||||
|
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
fprintf(stderr, "%s: general-purpose mode\n", __func__);
|
fprintf(stderr, "%s: general-purpose mode\n", __func__);
|
||||||
@ -589,11 +533,9 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au
|
|||||||
// wait for activation phrase
|
// wait for activation phrase
|
||||||
audio.get(params.prompt_ms, pcmf32_cur);
|
audio.get(params.prompt_ms, pcmf32_cur);
|
||||||
|
|
||||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "prompt", logprob_min0, logprob_sum0, n_tokens0, t_ms));
|
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
|
||||||
|
|
||||||
const float p = 100.0f * std::exp(logprob_min0);
|
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
|
||||||
|
|
||||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms, p = %.2f%%)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms, p);
|
|
||||||
|
|
||||||
const float sim = similarity(txt, k_prompt);
|
const float sim = similarity(txt, k_prompt);
|
||||||
|
|
||||||
@ -614,30 +556,19 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au
|
|||||||
// we have heard the activation phrase, now detect the commands
|
// we have heard the activation phrase, now detect the commands
|
||||||
audio.get(params.command_ms, pcmf32_cur);
|
audio.get(params.command_ms, pcmf32_cur);
|
||||||
|
|
||||||
//printf("len prompt: %.4f\n", pcmf32_prompt.size() / (float) WHISPER_SAMPLE_RATE);
|
|
||||||
//printf("len command: %.4f\n", pcmf32_cur.size() / (float) WHISPER_SAMPLE_RATE);
|
|
||||||
|
|
||||||
// prepend 3 second of silence
|
|
||||||
pcmf32_cur.insert(pcmf32_cur.begin(), 3.0f*WHISPER_SAMPLE_RATE, 0.0f);
|
|
||||||
|
|
||||||
// prepend the prompt audio
|
// prepend the prompt audio
|
||||||
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
||||||
|
|
||||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "root", logprob_min, logprob_sum, n_tokens, t_ms));
|
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
||||||
|
|
||||||
//const float p = 100.0f * std::exp((logprob - logprob0) / (n_tokens - n_tokens0));
|
prob = 100.0f*(prob - prob0);
|
||||||
const float p = 100.0f * std::exp(logprob_min);
|
|
||||||
|
|
||||||
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
|
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
|
||||||
|
|
||||||
// find the prompt in the text
|
// find the prompt in the text
|
||||||
float best_sim = 0.0f;
|
float best_sim = 0.0f;
|
||||||
size_t best_len = 0;
|
size_t best_len = 0;
|
||||||
for (size_t n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
||||||
if (n >= txt.size()) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto prompt = txt.substr(0, n);
|
const auto prompt = txt.substr(0, n);
|
||||||
|
|
||||||
const float sim = similarity(prompt, k_prompt);
|
const float sim = similarity(prompt, k_prompt);
|
||||||
@ -650,16 +581,9 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p);
|
|
||||||
if (best_len == 0) {
|
|
||||||
fprintf(stdout, "%s: WARNING: command not recognized, try again\n", __func__);
|
|
||||||
} else {
|
|
||||||
// cut the prompt from the decoded text
|
|
||||||
const std::string command = ::trim(txt.substr(best_len));
|
const std::string command = ::trim(txt.substr(best_len));
|
||||||
|
|
||||||
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
||||||
}
|
|
||||||
|
|
||||||
fprintf(stdout, "\n");
|
fprintf(stdout, "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -724,37 +648,13 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
int ret_val = 0;
|
int ret_val = 0;
|
||||||
|
|
||||||
if (!params.grammar.empty()) {
|
|
||||||
auto & grammar = params.grammar_parsed;
|
|
||||||
if (file_exists(params.grammar.c_str())) {
|
|
||||||
// read grammar from file
|
|
||||||
std::ifstream ifs(params.grammar.c_str());
|
|
||||||
const std::string txt = std::string((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
|
|
||||||
grammar = grammar_parser::parse(txt.c_str());
|
|
||||||
} else {
|
|
||||||
// read grammar from string
|
|
||||||
grammar = grammar_parser::parse(params.grammar.c_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
// will be empty (default) if there are parse errors
|
|
||||||
if (grammar.rules.empty()) {
|
|
||||||
ret_val = 1;
|
|
||||||
} else {
|
|
||||||
fprintf(stderr, "%s: grammar:\n", __func__);
|
|
||||||
grammar_parser::print_grammar(stderr, grammar);
|
|
||||||
fprintf(stderr, "\n");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ret_val == 0) {
|
|
||||||
if (!params.commands.empty()) {
|
if (!params.commands.empty()) {
|
||||||
ret_val = process_command_list(ctx, audio, params);
|
ret_val = process_command_list(ctx, audio, params);
|
||||||
} else if (!params.prompt.empty() && params.grammar_parsed.rules.empty()) {
|
} else if (!params.prompt.empty()) {
|
||||||
ret_val = always_prompt_transcription(ctx, audio, params);
|
ret_val = always_prompt_transcription(ctx, audio, params);
|
||||||
} else {
|
} else {
|
||||||
ret_val = process_general_transcription(ctx, audio, params);
|
ret_val = process_general_transcription(ctx, audio, params);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
audio.pause();
|
audio.pause();
|
||||||
|
|
||||||
|
@ -792,7 +792,7 @@ bool sam_params_parse(int argc, char ** argv, sam_params & params) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void sam_print_usage(int argc, char ** argv, const sam_params & params) {
|
void sam_print_usage(int /*argc*/, char ** argv, const sam_params & params) {
|
||||||
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
fprintf(stderr, "options:\n");
|
fprintf(stderr, "options:\n");
|
||||||
|
@ -7,6 +7,8 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
|
#include <ctime>
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
#define COMMON_SAMPLE_RATE 16000
|
#define COMMON_SAMPLE_RATE 16000
|
||||||
|
|
||||||
@ -139,6 +141,104 @@ bool read_wav(
|
|||||||
std::vector<std::vector<float>> & pcmf32s,
|
std::vector<std::vector<float>> & pcmf32s,
|
||||||
bool stereo);
|
bool stereo);
|
||||||
|
|
||||||
|
// Write PCM data into WAV audio file
|
||||||
|
class wav_writer {
|
||||||
|
private:
|
||||||
|
std::ofstream file;
|
||||||
|
uint32_t dataSize = 0;
|
||||||
|
std::string wav_filename;
|
||||||
|
|
||||||
|
bool write_header(const uint32_t sample_rate,
|
||||||
|
const uint16_t bits_per_sample,
|
||||||
|
const uint16_t channels) {
|
||||||
|
|
||||||
|
file.write("RIFF", 4);
|
||||||
|
file.write("\0\0\0\0", 4); // Placeholder for file size
|
||||||
|
file.write("WAVE", 4);
|
||||||
|
file.write("fmt ", 4);
|
||||||
|
|
||||||
|
const uint32_t sub_chunk_size = 16;
|
||||||
|
const uint16_t audio_format = 1; // PCM format
|
||||||
|
const uint32_t byte_rate = sample_rate * channels * bits_per_sample / 8;
|
||||||
|
const uint16_t block_align = channels * bits_per_sample / 8;
|
||||||
|
|
||||||
|
file.write(reinterpret_cast<const char *>(&sub_chunk_size), 4);
|
||||||
|
file.write(reinterpret_cast<const char *>(&audio_format), 2);
|
||||||
|
file.write(reinterpret_cast<const char *>(&channels), 2);
|
||||||
|
file.write(reinterpret_cast<const char *>(&sample_rate), 4);
|
||||||
|
file.write(reinterpret_cast<const char *>(&byte_rate), 4);
|
||||||
|
file.write(reinterpret_cast<const char *>(&block_align), 2);
|
||||||
|
file.write(reinterpret_cast<const char *>(&bits_per_sample), 2);
|
||||||
|
file.write("data", 4);
|
||||||
|
file.write("\0\0\0\0", 4); // Placeholder for data size
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// It is assumed that PCM data is normalized to a range from -1 to 1
|
||||||
|
bool write_audio(const float * data, size_t length) {
|
||||||
|
for (size_t i = 0; i < length; ++i) {
|
||||||
|
const auto intSample = static_cast<const int16_t>(data[i] * 32767);
|
||||||
|
file.write(reinterpret_cast<const char *>(&intSample), sizeof(int16_t));
|
||||||
|
dataSize += sizeof(int16_t);
|
||||||
|
}
|
||||||
|
if (file.is_open()) {
|
||||||
|
file.seekp(4, std::ios::beg);
|
||||||
|
uint32_t fileSize = 36 + dataSize;
|
||||||
|
file.write(reinterpret_cast<char *>(&fileSize), 4);
|
||||||
|
file.seekp(40, std::ios::beg);
|
||||||
|
file.write(reinterpret_cast<char *>(&dataSize), 4);
|
||||||
|
file.seekp(0, std::ios::end);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool open_wav(const std::string & filename) {
|
||||||
|
if (filename != wav_filename) {
|
||||||
|
if (file.is_open()) {
|
||||||
|
file.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!file.is_open()) {
|
||||||
|
file.open(filename, std::ios::binary);
|
||||||
|
wav_filename = filename;
|
||||||
|
dataSize = 0;
|
||||||
|
}
|
||||||
|
return file.is_open();
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
bool open(const std::string & filename,
|
||||||
|
const uint32_t sample_rate,
|
||||||
|
const uint16_t bits_per_sample,
|
||||||
|
const uint16_t channels) {
|
||||||
|
|
||||||
|
if (open_wav(filename)) {
|
||||||
|
write_header(sample_rate, bits_per_sample, channels);
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool close() {
|
||||||
|
file.close();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool write(const float * data, size_t length) {
|
||||||
|
return write_audio(data, length);
|
||||||
|
}
|
||||||
|
|
||||||
|
~wav_writer() {
|
||||||
|
if (file.is_open()) {
|
||||||
|
file.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
// Apply a high-pass frequency filter to PCM audio
|
// Apply a high-pass frequency filter to PCM audio
|
||||||
// Suppresses frequencies below cutoff Hz
|
// Suppresses frequencies below cutoff Hz
|
||||||
void high_pass_filter(
|
void high_pass_filter(
|
||||||
|
@ -1,423 +0,0 @@
|
|||||||
#include "grammar-parser.h"
|
|
||||||
#include <cstdint>
|
|
||||||
#include <cwchar>
|
|
||||||
#include <string>
|
|
||||||
#include <utility>
|
|
||||||
#include <stdexcept>
|
|
||||||
#include <exception>
|
|
||||||
|
|
||||||
namespace grammar_parser {
|
|
||||||
// NOTE: assumes valid utf8 (but checks for overrun)
|
|
||||||
// copied from whisper.cpp
|
|
||||||
std::pair<uint32_t, const char *> decode_utf8(const char * src) {
|
|
||||||
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
|
||||||
uint8_t first_byte = static_cast<uint8_t>(*src);
|
|
||||||
uint8_t highbits = first_byte >> 4;
|
|
||||||
int len = lookup[highbits];
|
|
||||||
uint8_t mask = (1 << (8 - len)) - 1;
|
|
||||||
uint32_t value = first_byte & mask;
|
|
||||||
const char * end = src + len; // may overrun!
|
|
||||||
const char * pos = src + 1;
|
|
||||||
for ( ; pos < end && *pos; pos++) {
|
|
||||||
value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
|
|
||||||
}
|
|
||||||
return std::make_pair(value, pos);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
|
|
||||||
uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
|
|
||||||
auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id));
|
|
||||||
return result.first->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
|
|
||||||
uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
|
|
||||||
state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
|
|
||||||
return next_id;
|
|
||||||
}
|
|
||||||
|
|
||||||
void add_rule(
|
|
||||||
parse_state & state,
|
|
||||||
uint32_t rule_id,
|
|
||||||
const std::vector<whisper_grammar_element> & rule) {
|
|
||||||
if (state.rules.size() <= rule_id) {
|
|
||||||
state.rules.resize(rule_id + 1);
|
|
||||||
}
|
|
||||||
state.rules[rule_id] = rule;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_word_char(char c) {
|
|
||||||
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
|
|
||||||
}
|
|
||||||
|
|
||||||
std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
|
|
||||||
const char * pos = src;
|
|
||||||
const char * end = src + size;
|
|
||||||
uint32_t value = 0;
|
|
||||||
for ( ; pos < end && *pos; pos++) {
|
|
||||||
value <<= 4;
|
|
||||||
char c = *pos;
|
|
||||||
if ('a' <= c && c <= 'f') {
|
|
||||||
value += c - 'a' + 10;
|
|
||||||
} else if ('A' <= c && c <= 'F') {
|
|
||||||
value += c - 'A' + 10;
|
|
||||||
} else if ('0' <= c && c <= '9') {
|
|
||||||
value += c - '0';
|
|
||||||
} else {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (pos != end) {
|
|
||||||
throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
|
|
||||||
}
|
|
||||||
return std::make_pair(value, pos);
|
|
||||||
}
|
|
||||||
|
|
||||||
const char * parse_space(const char * src, bool newline_ok) {
|
|
||||||
const char * pos = src;
|
|
||||||
while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
|
|
||||||
(newline_ok && (*pos == '\r' || *pos == '\n'))) {
|
|
||||||
if (*pos == '#') {
|
|
||||||
while (*pos && *pos != '\r' && *pos != '\n') {
|
|
||||||
pos++;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
pos++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return pos;
|
|
||||||
}
|
|
||||||
|
|
||||||
const char * parse_name(const char * src) {
|
|
||||||
const char * pos = src;
|
|
||||||
while (is_word_char(*pos)) {
|
|
||||||
pos++;
|
|
||||||
}
|
|
||||||
if (pos == src) {
|
|
||||||
throw std::runtime_error(std::string("expecting name at ") + src);
|
|
||||||
}
|
|
||||||
return pos;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::pair<uint32_t, const char *> parse_char(const char * src) {
|
|
||||||
if (*src == '\\') {
|
|
||||||
switch (src[1]) {
|
|
||||||
case 'x': return parse_hex(src + 2, 2);
|
|
||||||
case 'u': return parse_hex(src + 2, 4);
|
|
||||||
case 'U': return parse_hex(src + 2, 8);
|
|
||||||
case 't': return std::make_pair('\t', src + 2);
|
|
||||||
case 'r': return std::make_pair('\r', src + 2);
|
|
||||||
case 'n': return std::make_pair('\n', src + 2);
|
|
||||||
case '\\':
|
|
||||||
case '"':
|
|
||||||
case '[':
|
|
||||||
case ']':
|
|
||||||
return std::make_pair(src[1], src + 2);
|
|
||||||
default:
|
|
||||||
throw std::runtime_error(std::string("unknown escape at ") + src);
|
|
||||||
}
|
|
||||||
} else if (*src) {
|
|
||||||
return decode_utf8(src);
|
|
||||||
}
|
|
||||||
throw std::runtime_error("unexpected end of input");
|
|
||||||
}
|
|
||||||
|
|
||||||
const char * parse_alternates(
|
|
||||||
parse_state & state,
|
|
||||||
const char * src,
|
|
||||||
const std::string & rule_name,
|
|
||||||
uint32_t rule_id,
|
|
||||||
bool is_nested);
|
|
||||||
|
|
||||||
const char * parse_sequence(
|
|
||||||
parse_state & state,
|
|
||||||
const char * src,
|
|
||||||
const std::string & rule_name,
|
|
||||||
std::vector<whisper_grammar_element> & out_elements,
|
|
||||||
bool is_nested) {
|
|
||||||
size_t last_sym_start = out_elements.size();
|
|
||||||
const char * pos = src;
|
|
||||||
while (*pos) {
|
|
||||||
if (*pos == '"') { // literal string
|
|
||||||
pos++;
|
|
||||||
last_sym_start = out_elements.size();
|
|
||||||
while (*pos != '"') {
|
|
||||||
auto char_pair = parse_char(pos);
|
|
||||||
pos = char_pair.second;
|
|
||||||
out_elements.push_back({WHISPER_GRETYPE_CHAR, char_pair.first});
|
|
||||||
}
|
|
||||||
pos = parse_space(pos + 1, is_nested);
|
|
||||||
} else if (*pos == '[') { // char range(s)
|
|
||||||
pos++;
|
|
||||||
enum whisper_gretype start_type = WHISPER_GRETYPE_CHAR;
|
|
||||||
if (*pos == '^') {
|
|
||||||
pos++;
|
|
||||||
start_type = WHISPER_GRETYPE_CHAR_NOT;
|
|
||||||
}
|
|
||||||
last_sym_start = out_elements.size();
|
|
||||||
while (*pos != ']') {
|
|
||||||
auto char_pair = parse_char(pos);
|
|
||||||
pos = char_pair.second;
|
|
||||||
enum whisper_gretype type = last_sym_start < out_elements.size()
|
|
||||||
? WHISPER_GRETYPE_CHAR_ALT
|
|
||||||
: start_type;
|
|
||||||
|
|
||||||
out_elements.push_back({type, char_pair.first});
|
|
||||||
if (pos[0] == '-' && pos[1] != ']') {
|
|
||||||
auto endchar_pair = parse_char(pos + 1);
|
|
||||||
pos = endchar_pair.second;
|
|
||||||
out_elements.push_back({WHISPER_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
pos = parse_space(pos + 1, is_nested);
|
|
||||||
} else if (is_word_char(*pos)) { // rule reference
|
|
||||||
const char * name_end = parse_name(pos);
|
|
||||||
uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos);
|
|
||||||
pos = parse_space(name_end, is_nested);
|
|
||||||
last_sym_start = out_elements.size();
|
|
||||||
out_elements.push_back({WHISPER_GRETYPE_RULE_REF, ref_rule_id});
|
|
||||||
} else if (*pos == '(') { // grouping
|
|
||||||
// parse nested alternates into synthesized rule
|
|
||||||
pos = parse_space(pos + 1, true);
|
|
||||||
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
|
|
||||||
pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
|
|
||||||
last_sym_start = out_elements.size();
|
|
||||||
// output reference to synthesized rule
|
|
||||||
out_elements.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
|
|
||||||
if (*pos != ')') {
|
|
||||||
throw std::runtime_error(std::string("expecting ')' at ") + pos);
|
|
||||||
}
|
|
||||||
pos = parse_space(pos + 1, is_nested);
|
|
||||||
} else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
|
|
||||||
if (last_sym_start == out_elements.size()) {
|
|
||||||
throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos);
|
|
||||||
}
|
|
||||||
|
|
||||||
// apply transformation to previous symbol (last_sym_start to end) according to
|
|
||||||
// rewrite rules:
|
|
||||||
// S* --> S' ::= S S' |
|
|
||||||
// S+ --> S' ::= S S' | S
|
|
||||||
// S? --> S' ::= S |
|
|
||||||
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
|
|
||||||
std::vector<whisper_grammar_element> sub_rule;
|
|
||||||
// add preceding symbol to generated rule
|
|
||||||
sub_rule.insert(
|
|
||||||
sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
|
|
||||||
if (*pos == '*' || *pos == '+') {
|
|
||||||
// cause generated rule to recurse
|
|
||||||
sub_rule.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
|
|
||||||
}
|
|
||||||
// mark start of alternate def
|
|
||||||
sub_rule.push_back({WHISPER_GRETYPE_ALT, 0});
|
|
||||||
if (*pos == '+') {
|
|
||||||
// add preceding symbol as alternate only for '+' (otherwise empty)
|
|
||||||
sub_rule.insert(
|
|
||||||
sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
|
|
||||||
}
|
|
||||||
sub_rule.push_back({WHISPER_GRETYPE_END, 0});
|
|
||||||
add_rule(state, sub_rule_id, sub_rule);
|
|
||||||
|
|
||||||
// in original rule, replace previous symbol with reference to generated rule
|
|
||||||
out_elements.resize(last_sym_start);
|
|
||||||
out_elements.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
|
|
||||||
|
|
||||||
pos = parse_space(pos + 1, is_nested);
|
|
||||||
} else {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return pos;
|
|
||||||
}
|
|
||||||
|
|
||||||
const char * parse_alternates(
|
|
||||||
parse_state & state,
|
|
||||||
const char * src,
|
|
||||||
const std::string & rule_name,
|
|
||||||
uint32_t rule_id,
|
|
||||||
bool is_nested) {
|
|
||||||
std::vector<whisper_grammar_element> rule;
|
|
||||||
const char * pos = parse_sequence(state, src, rule_name, rule, is_nested);
|
|
||||||
while (*pos == '|') {
|
|
||||||
rule.push_back({WHISPER_GRETYPE_ALT, 0});
|
|
||||||
pos = parse_space(pos + 1, true);
|
|
||||||
pos = parse_sequence(state, pos, rule_name, rule, is_nested);
|
|
||||||
}
|
|
||||||
rule.push_back({WHISPER_GRETYPE_END, 0});
|
|
||||||
add_rule(state, rule_id, rule);
|
|
||||||
return pos;
|
|
||||||
}
|
|
||||||
|
|
||||||
const char * parse_rule(parse_state & state, const char * src) {
|
|
||||||
const char * name_end = parse_name(src);
|
|
||||||
const char * pos = parse_space(name_end, false);
|
|
||||||
size_t name_len = name_end - src;
|
|
||||||
uint32_t rule_id = get_symbol_id(state, src, name_len);
|
|
||||||
const std::string name(src, name_len);
|
|
||||||
|
|
||||||
if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
|
|
||||||
throw std::runtime_error(std::string("expecting ::= at ") + pos);
|
|
||||||
}
|
|
||||||
pos = parse_space(pos + 3, true);
|
|
||||||
|
|
||||||
pos = parse_alternates(state, pos, name, rule_id, false);
|
|
||||||
|
|
||||||
if (*pos == '\r') {
|
|
||||||
pos += pos[1] == '\n' ? 2 : 1;
|
|
||||||
} else if (*pos == '\n') {
|
|
||||||
pos++;
|
|
||||||
} else if (*pos) {
|
|
||||||
throw std::runtime_error(std::string("expecting newline or end at ") + pos);
|
|
||||||
}
|
|
||||||
return parse_space(pos, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
parse_state parse(const char * src) {
|
|
||||||
try {
|
|
||||||
parse_state state;
|
|
||||||
const char * pos = parse_space(src, true);
|
|
||||||
while (*pos) {
|
|
||||||
pos = parse_rule(state, pos);
|
|
||||||
}
|
|
||||||
return state;
|
|
||||||
} catch (const std::exception & err) {
|
|
||||||
fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
|
|
||||||
return parse_state();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void print_grammar_char(FILE * file, uint32_t c) {
|
|
||||||
if (0x20 <= c && c <= 0x7f) {
|
|
||||||
fprintf(file, "%c", static_cast<char>(c));
|
|
||||||
} else {
|
|
||||||
// cop out of encoding UTF-8
|
|
||||||
fprintf(file, "<U+%04X>", c);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_char_element(whisper_grammar_element elem) {
|
|
||||||
switch (elem.type) {
|
|
||||||
case WHISPER_GRETYPE_CHAR: return true;
|
|
||||||
case WHISPER_GRETYPE_CHAR_NOT: return true;
|
|
||||||
case WHISPER_GRETYPE_CHAR_ALT: return true;
|
|
||||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER: return true;
|
|
||||||
default: return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void print_rule_binary(FILE * file, const std::vector<whisper_grammar_element> & rule) {
|
|
||||||
for (auto elem : rule) {
|
|
||||||
switch (elem.type) {
|
|
||||||
case WHISPER_GRETYPE_END: fprintf(file, "END"); break;
|
|
||||||
case WHISPER_GRETYPE_ALT: fprintf(file, "ALT"); break;
|
|
||||||
case WHISPER_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
|
|
||||||
case WHISPER_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
|
|
||||||
case WHISPER_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
|
|
||||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
|
|
||||||
case WHISPER_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
|
|
||||||
}
|
|
||||||
switch (elem.type) {
|
|
||||||
case WHISPER_GRETYPE_END:
|
|
||||||
case WHISPER_GRETYPE_ALT:
|
|
||||||
case WHISPER_GRETYPE_RULE_REF:
|
|
||||||
fprintf(file, "(%u) ", elem.value);
|
|
||||||
break;
|
|
||||||
case WHISPER_GRETYPE_CHAR:
|
|
||||||
case WHISPER_GRETYPE_CHAR_NOT:
|
|
||||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER:
|
|
||||||
case WHISPER_GRETYPE_CHAR_ALT:
|
|
||||||
fprintf(file, "(\"");
|
|
||||||
print_grammar_char(file, elem.value);
|
|
||||||
fprintf(file, "\") ");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fprintf(file, "\n");
|
|
||||||
}
|
|
||||||
|
|
||||||
void print_rule(
|
|
||||||
FILE * file,
|
|
||||||
uint32_t rule_id,
|
|
||||||
const std::vector<whisper_grammar_element> & rule,
|
|
||||||
const std::map<uint32_t, std::string> & symbol_id_names) {
|
|
||||||
if (rule.empty() || rule.back().type != WHISPER_GRETYPE_END) {
|
|
||||||
throw std::runtime_error(
|
|
||||||
"malformed rule, does not end with WHISPER_GRETYPE_END: " + std::to_string(rule_id));
|
|
||||||
}
|
|
||||||
fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
|
|
||||||
for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
|
|
||||||
whisper_grammar_element elem = rule[i];
|
|
||||||
switch (elem.type) {
|
|
||||||
case WHISPER_GRETYPE_END:
|
|
||||||
throw std::runtime_error(
|
|
||||||
"unexpected end of rule: " + std::to_string(rule_id) + "," +
|
|
||||||
std::to_string(i));
|
|
||||||
case WHISPER_GRETYPE_ALT:
|
|
||||||
fprintf(file, "| ");
|
|
||||||
break;
|
|
||||||
case WHISPER_GRETYPE_RULE_REF:
|
|
||||||
fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
|
|
||||||
break;
|
|
||||||
case WHISPER_GRETYPE_CHAR:
|
|
||||||
fprintf(file, "[");
|
|
||||||
print_grammar_char(file, elem.value);
|
|
||||||
break;
|
|
||||||
case WHISPER_GRETYPE_CHAR_NOT:
|
|
||||||
fprintf(file, "[^");
|
|
||||||
print_grammar_char(file, elem.value);
|
|
||||||
break;
|
|
||||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER:
|
|
||||||
if (i == 0 || !is_char_element(rule[i - 1])) {
|
|
||||||
throw std::runtime_error(
|
|
||||||
"WHISPER_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
|
|
||||||
std::to_string(rule_id) + "," + std::to_string(i));
|
|
||||||
}
|
|
||||||
fprintf(file, "-");
|
|
||||||
print_grammar_char(file, elem.value);
|
|
||||||
break;
|
|
||||||
case WHISPER_GRETYPE_CHAR_ALT:
|
|
||||||
if (i == 0 || !is_char_element(rule[i - 1])) {
|
|
||||||
throw std::runtime_error(
|
|
||||||
"WHISPER_GRETYPE_CHAR_ALT without preceding char: " +
|
|
||||||
std::to_string(rule_id) + "," + std::to_string(i));
|
|
||||||
}
|
|
||||||
print_grammar_char(file, elem.value);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (is_char_element(elem)) {
|
|
||||||
switch (rule[i + 1].type) {
|
|
||||||
case WHISPER_GRETYPE_CHAR_ALT:
|
|
||||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER:
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
fprintf(file, "] ");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fprintf(file, "\n");
|
|
||||||
}
|
|
||||||
|
|
||||||
void print_grammar(FILE * file, const parse_state & state) {
|
|
||||||
try {
|
|
||||||
std::map<uint32_t, std::string> symbol_id_names;
|
|
||||||
for (auto kv : state.symbol_ids) {
|
|
||||||
symbol_id_names[kv.second] = kv.first;
|
|
||||||
}
|
|
||||||
for (size_t i = 0, end = state.rules.size(); i < end; i++) {
|
|
||||||
// fprintf(file, "%zu: ", i);
|
|
||||||
// print_rule_binary(file, state.rules[i]);
|
|
||||||
print_rule(file, uint32_t(i), state.rules[i], symbol_id_names);
|
|
||||||
// fprintf(file, "\n");
|
|
||||||
}
|
|
||||||
} catch (const std::exception & err) {
|
|
||||||
fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<const whisper_grammar_element *> parse_state::c_rules() const{
|
|
||||||
std::vector<const whisper_grammar_element *> ret;
|
|
||||||
for (const auto & rule : rules) {
|
|
||||||
ret.push_back(rule.data());
|
|
||||||
}
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,29 +0,0 @@
|
|||||||
// Implements a parser for an extended Backus-Naur form (BNF), producing the
|
|
||||||
// binary context-free grammar format specified by whisper.h. Supports character
|
|
||||||
// ranges, grouping, and repetition operators. As an example, a grammar for
|
|
||||||
// arithmetic might look like:
|
|
||||||
//
|
|
||||||
// root ::= expr
|
|
||||||
// expr ::= term ([-+*/] term)*
|
|
||||||
// term ::= num | "(" space expr ")" space
|
|
||||||
// num ::= [0-9]+ space
|
|
||||||
// space ::= [ \t\n]*
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
#include "whisper.h"
|
|
||||||
#include <vector>
|
|
||||||
#include <map>
|
|
||||||
#include <cstdint>
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
namespace grammar_parser {
|
|
||||||
struct parse_state {
|
|
||||||
std::map<std::string, uint32_t> symbol_ids;
|
|
||||||
std::vector<std::vector<whisper_grammar_element>> rules;
|
|
||||||
|
|
||||||
std::vector<const whisper_grammar_element *> c_rules() const;
|
|
||||||
};
|
|
||||||
|
|
||||||
parse_state parse(const char * src);
|
|
||||||
void print_grammar(FILE * file, const parse_state & state);
|
|
||||||
}
|
|
@ -324,7 +324,7 @@ json register_commandset(struct whisper_context * ctx, json jparams, std::vector
|
|||||||
commandset_list.push_back(cs);
|
commandset_list.push_back(cs);
|
||||||
return json{{"index",index}};
|
return json{{"index",index}};
|
||||||
}
|
}
|
||||||
json seek(struct whisper_context * ctx, audio_async &audio, json params) {
|
json seek(struct whisper_context * /*ctx*/, audio_async & /*audio*/, json /*params*/) {
|
||||||
// whisper_state has the pertinent offsets, but there also seem to be a large
|
// whisper_state has the pertinent offsets, but there also seem to be a large
|
||||||
// number of scratch buffers that would prevent rewinding context in a manner similar to llama
|
// number of scratch buffers that would prevent rewinding context in a manner similar to llama
|
||||||
// I'll give this a another pass once everything else is implemented,
|
// I'll give this a another pass once everything else is implemented,
|
||||||
@ -412,7 +412,7 @@ void process_loop(struct whisper_context * ctx, audio_async &audio, const whispe
|
|||||||
jobqueue.pop_front();
|
jobqueue.pop_front();
|
||||||
// send response
|
// send response
|
||||||
std::string data = resp.dump(-1, ' ', false, json::error_handler_t::replace);
|
std::string data = resp.dump(-1, ' ', false, json::error_handler_t::replace);
|
||||||
fprintf(stdout, "Content-Length: %d\r\n\r\n%s\n", data.length()+1, data.c_str());
|
fprintf(stdout, "Content-Length: %d\r\n\r\n%s\n", (int)data.length()+1, data.c_str());
|
||||||
std::cout.flush();
|
std::cout.flush();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -83,6 +83,7 @@ struct whisper_params {
|
|||||||
bool output_wts = false;
|
bool output_wts = false;
|
||||||
bool output_csv = false;
|
bool output_csv = false;
|
||||||
bool output_jsn = false;
|
bool output_jsn = false;
|
||||||
|
bool output_jsn_full = false;
|
||||||
bool output_lrc = false;
|
bool output_lrc = false;
|
||||||
bool print_special = false;
|
bool print_special = false;
|
||||||
bool print_colors = false;
|
bool print_colors = false;
|
||||||
@ -151,6 +152,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|||||||
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
|
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
|
||||||
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
|
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
|
||||||
else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
|
else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
|
||||||
|
else if (arg == "-ojf" || arg == "--output-json-full"){ params.output_jsn_full = params.output_jsn = true; }
|
||||||
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
|
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
|
||||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||||
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
||||||
@ -206,6 +208,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|||||||
fprintf(stderr, " -fp, --font-path [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str());
|
fprintf(stderr, " -fp, --font-path [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str());
|
||||||
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
|
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
|
||||||
fprintf(stderr, " -oj, --output-json [%-7s] output result in a JSON file\n", params.output_jsn ? "true" : "false");
|
fprintf(stderr, " -oj, --output-json [%-7s] output result in a JSON file\n", params.output_jsn ? "true" : "false");
|
||||||
|
fprintf(stderr, " -ojf, --output-json-full [%-7s] include more information in the JSON file\n", params.output_jsn_full ? "true" : "false");
|
||||||
fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", "");
|
fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", "");
|
||||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||||
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
|
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
|
||||||
@ -260,7 +263,7 @@ std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s
|
|||||||
|
|
||||||
return speaker;
|
return speaker;
|
||||||
}
|
}
|
||||||
void whisper_print_progress_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int progress, void * user_data) {
|
void whisper_print_progress_callback(struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void * user_data) {
|
||||||
int progress_step = ((whisper_print_user_data *) user_data)->params->progress_step;
|
int progress_step = ((whisper_print_user_data *) user_data)->params->progress_step;
|
||||||
int * progress_prev = &(((whisper_print_user_data *) user_data)->progress_prev);
|
int * progress_prev = &(((whisper_print_user_data *) user_data)->progress_prev);
|
||||||
if (progress >= *progress_prev + progress_step) {
|
if (progress >= *progress_prev + progress_step) {
|
||||||
@ -492,7 +495,7 @@ bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool output_score(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
|
bool output_score(struct whisper_context * ctx, const char * fname, const whisper_params & /*params*/, std::vector<std::vector<float>> /*pcmf32s*/) {
|
||||||
std::ofstream fout(fname);
|
std::ofstream fout(fname);
|
||||||
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
|
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
|
||||||
|
|
||||||
@ -511,7 +514,12 @@ bool output_score(struct whisper_context * ctx, const char * fname, const whispe
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
|
bool output_json(
|
||||||
|
struct whisper_context * ctx,
|
||||||
|
const char * fname,
|
||||||
|
const whisper_params & params,
|
||||||
|
std::vector<std::vector<float>> pcmf32s,
|
||||||
|
bool full) {
|
||||||
std::ofstream fout(fname);
|
std::ofstream fout(fname);
|
||||||
int indent = 0;
|
int indent = 0;
|
||||||
|
|
||||||
@ -528,7 +536,7 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
|
|||||||
auto end_arr = [&](bool end) {
|
auto end_arr = [&](bool end) {
|
||||||
indent--;
|
indent--;
|
||||||
doindent();
|
doindent();
|
||||||
fout << (end ? "]\n" : "},\n");
|
fout << (end ? "]\n" : "],\n");
|
||||||
};
|
};
|
||||||
|
|
||||||
auto start_obj = [&](const char *name) {
|
auto start_obj = [&](const char *name) {
|
||||||
@ -569,12 +577,29 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
|
|||||||
end_value(end);
|
end_value(end);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
auto value_f = [&](const char *name, const float val, bool end) {
|
||||||
|
start_value(name);
|
||||||
|
fout << val;
|
||||||
|
end_value(end);
|
||||||
|
};
|
||||||
|
|
||||||
auto value_b = [&](const char *name, const bool val, bool end) {
|
auto value_b = [&](const char *name, const bool val, bool end) {
|
||||||
start_value(name);
|
start_value(name);
|
||||||
fout << (val ? "true" : "false");
|
fout << (val ? "true" : "false");
|
||||||
end_value(end);
|
end_value(end);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
auto times_o = [&](int64_t t0, int64_t t1, bool end) {
|
||||||
|
start_obj("timestamps");
|
||||||
|
value_s("from", to_timestamp(t0, true).c_str(), false);
|
||||||
|
value_s("to", to_timestamp(t1, true).c_str(), true);
|
||||||
|
end_obj(false);
|
||||||
|
start_obj("offsets");
|
||||||
|
value_i("from", t0 * 10, false);
|
||||||
|
value_i("to", t1 * 10, true);
|
||||||
|
end_obj(end);
|
||||||
|
};
|
||||||
|
|
||||||
if (!fout.is_open()) {
|
if (!fout.is_open()) {
|
||||||
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
|
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
|
||||||
return false;
|
return false;
|
||||||
@ -620,15 +645,26 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
|
|||||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||||
|
|
||||||
start_obj(nullptr);
|
start_obj(nullptr);
|
||||||
start_obj("timestamps");
|
times_o(t0, t1, false);
|
||||||
value_s("from", to_timestamp(t0, true).c_str(), false);
|
value_s("text", text, !params.diarize && !params.tinydiarize && !full);
|
||||||
value_s("to", to_timestamp(t1, true).c_str(), true);
|
|
||||||
end_obj(false);
|
if (full) {
|
||||||
start_obj("offsets");
|
start_arr("tokens");
|
||||||
value_i("from", t0 * 10, false);
|
const int n = whisper_full_n_tokens(ctx, i);
|
||||||
value_i("to", t1 * 10, true);
|
for (int j = 0; j < n; ++j) {
|
||||||
end_obj(false);
|
auto token = whisper_full_get_token_data(ctx, i, j);
|
||||||
value_s("text", text, !params.diarize && !params.tinydiarize);
|
start_obj(nullptr);
|
||||||
|
value_s("text", whisper_token_to_str(ctx, token.id), false);
|
||||||
|
if(token.t0 > -1 && token.t1 > -1) {
|
||||||
|
// If we have per-token timestamps, write them out
|
||||||
|
times_o(token.t0, token.t1, false);
|
||||||
|
}
|
||||||
|
value_i("id", token.id, false);
|
||||||
|
value_f("p", token.p, true);
|
||||||
|
end_obj(j == (n - 1));
|
||||||
|
}
|
||||||
|
end_arr(!params.diarize && !params.tinydiarize);
|
||||||
|
}
|
||||||
|
|
||||||
if (params.diarize && pcmf32s.size() == 2) {
|
if (params.diarize && pcmf32s.size() == 2) {
|
||||||
value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true);
|
value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true);
|
||||||
@ -912,7 +948,7 @@ int main(int argc, char ** argv) {
|
|||||||
wparams.offset_ms = params.offset_t_ms;
|
wparams.offset_ms = params.offset_t_ms;
|
||||||
wparams.duration_ms = params.duration_ms;
|
wparams.duration_ms = params.duration_ms;
|
||||||
|
|
||||||
wparams.token_timestamps = params.output_wts || params.max_len > 0;
|
wparams.token_timestamps = params.output_wts || params.output_jsn_full || params.max_len > 0;
|
||||||
wparams.thold_pt = params.word_thold;
|
wparams.thold_pt = params.word_thold;
|
||||||
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
|
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
|
||||||
wparams.split_on_word = params.split_on_word;
|
wparams.split_on_word = params.split_on_word;
|
||||||
@ -944,8 +980,9 @@ int main(int argc, char ** argv) {
|
|||||||
wparams.progress_callback_user_data = &user_data;
|
wparams.progress_callback_user_data = &user_data;
|
||||||
}
|
}
|
||||||
|
|
||||||
// example for abort mechanism
|
// examples for abort mechanism
|
||||||
// in this example, we do not abort the processing, but we could if the flag is set to true
|
// in examples below, we do not abort the processing, but we could if the flag is set to true
|
||||||
|
|
||||||
// the callback is called before every encoder run - if it returns false, the processing is aborted
|
// the callback is called before every encoder run - if it returns false, the processing is aborted
|
||||||
{
|
{
|
||||||
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
||||||
@ -957,6 +994,17 @@ int main(int argc, char ** argv) {
|
|||||||
wparams.encoder_begin_callback_user_data = &is_aborted;
|
wparams.encoder_begin_callback_user_data = &is_aborted;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// the callback is called before every computation - if it returns true, the computation is aborted
|
||||||
|
{
|
||||||
|
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
||||||
|
|
||||||
|
wparams.abort_callback = [](void * user_data) {
|
||||||
|
bool is_aborted = *(bool*)user_data;
|
||||||
|
return is_aborted;
|
||||||
|
};
|
||||||
|
wparams.abort_callback_user_data = &is_aborted;
|
||||||
|
}
|
||||||
|
|
||||||
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
|
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
|
||||||
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
|
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
|
||||||
return 10;
|
return 10;
|
||||||
@ -1000,7 +1048,7 @@ int main(int argc, char ** argv) {
|
|||||||
// output to JSON file
|
// output to JSON file
|
||||||
if (params.output_jsn) {
|
if (params.output_jsn) {
|
||||||
const auto fname_jsn = fname_out + ".json";
|
const auto fname_jsn = fname_out + ".json";
|
||||||
output_json(ctx, fname_jsn.c_str(), params, pcmf32s);
|
output_json(ctx, fname_jsn.c_str(), params, pcmf32s, params.output_jsn_full);
|
||||||
}
|
}
|
||||||
|
|
||||||
// output to LRC file
|
// output to LRC file
|
||||||
|
@ -39,6 +39,20 @@ brew install sdl2
|
|||||||
make stream
|
make stream
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Ensure you are at the root of the repo when running `make stream`. Not within the `examples/stream` dir
|
||||||
|
as the libraries needed like `common-sdl.h` are located within `examples`. Attempting to compile within
|
||||||
|
`examples/steam` means your compiler cannot find them and it gives an error it cannot find the file.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
whisper.cpp/examples/stream$ make stream
|
||||||
|
g++ stream.cpp -o stream
|
||||||
|
stream.cpp:6:10: fatal error: common/sdl.h: No such file or directory
|
||||||
|
6 | #include "common/sdl.h"
|
||||||
|
| ^~~~~~~~~~~~~~
|
||||||
|
compilation terminated.
|
||||||
|
make: *** [<builtin>: stream] Error 1
|
||||||
|
```
|
||||||
|
|
||||||
## Web version
|
## Web version
|
||||||
|
|
||||||
This tool can also run in the browser: [examples/stream.wasm](/examples/stream.wasm)
|
This tool can also run in the browser: [examples/stream.wasm](/examples/stream.wasm)
|
||||||
|
@ -2,9 +2,8 @@
|
|||||||
//
|
//
|
||||||
// A very quick-n-dirty implementation serving mainly as a proof of concept.
|
// A very quick-n-dirty implementation serving mainly as a proof of concept.
|
||||||
//
|
//
|
||||||
|
|
||||||
#include "common.h"
|
|
||||||
#include "common-sdl.h"
|
#include "common-sdl.h"
|
||||||
|
#include "common.h"
|
||||||
#include "whisper.h"
|
#include "whisper.h"
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
@ -14,6 +13,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
|
||||||
|
|
||||||
// 500 -> 00:05.000
|
// 500 -> 00:05.000
|
||||||
// 6000 -> 01:00.000
|
// 6000 -> 01:00.000
|
||||||
std::string to_timestamp(int64_t t) {
|
std::string to_timestamp(int64_t t) {
|
||||||
@ -52,6 +52,7 @@ struct whisper_params {
|
|||||||
std::string language = "en";
|
std::string language = "en";
|
||||||
std::string model = "models/ggml-base.en.bin";
|
std::string model = "models/ggml-base.en.bin";
|
||||||
std::string fname_out;
|
std::string fname_out;
|
||||||
|
bool save_audio = false; // save audio to wav file
|
||||||
};
|
};
|
||||||
|
|
||||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||||
@ -82,6 +83,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|||||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||||
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
||||||
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
|
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
|
||||||
|
else if (arg == "-sa" || arg == "--save-audio") { params.save_audio = true; }
|
||||||
|
|
||||||
else {
|
else {
|
||||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||||
@ -117,6 +119,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|||||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
||||||
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
|
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
|
||||||
|
fprintf(stderr, " -sa, --save-audio [%-7s] save the recorded audio to a file\n", params.save_audio ? "true" : "false");
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -154,7 +157,6 @@ int main(int argc, char ** argv) {
|
|||||||
audio.resume();
|
audio.resume();
|
||||||
|
|
||||||
// whisper init
|
// whisper init
|
||||||
|
|
||||||
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1){
|
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1){
|
||||||
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
||||||
whisper_print_usage(argc, argv, params);
|
whisper_print_usage(argc, argv, params);
|
||||||
@ -212,7 +214,18 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
printf("[Start speaking]");
|
wav_writer wavWriter;
|
||||||
|
// save wav file
|
||||||
|
if (params.save_audio) {
|
||||||
|
// Get current date/time for filename
|
||||||
|
time_t now = time(0);
|
||||||
|
char buffer[80];
|
||||||
|
strftime(buffer, sizeof(buffer), "%Y%m%d%H%M%S", localtime(&now));
|
||||||
|
std::string filename = std::string(buffer) + ".wav";
|
||||||
|
|
||||||
|
wavWriter.open(filename, WHISPER_SAMPLE_RATE, 16, 1);
|
||||||
|
}
|
||||||
|
printf("[Start speaking]\n");
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
|
|
||||||
auto t_last = std::chrono::high_resolution_clock::now();
|
auto t_last = std::chrono::high_resolution_clock::now();
|
||||||
@ -220,6 +233,9 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
// main audio loop
|
// main audio loop
|
||||||
while (is_running) {
|
while (is_running) {
|
||||||
|
if (params.save_audio) {
|
||||||
|
wavWriter.write(pcmf32_new.data(), pcmf32_new.size());
|
||||||
|
}
|
||||||
// handle Ctrl + C
|
// handle Ctrl + C
|
||||||
is_running = sdl_poll_events();
|
is_running = sdl_poll_events();
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ if (WHISPER_SDL2)
|
|||||||
|
|
||||||
# TODO: this is temporary
|
# TODO: this is temporary
|
||||||
# need to export ggml symbols for MSVC, but too lazy ..
|
# need to export ggml symbols for MSVC, but too lazy ..
|
||||||
add_executable(${TARGET} talk-llama.cpp llama.cpp ../common.cpp ../common-sdl.cpp ../../ggml.c ../../whisper.cpp)
|
add_executable(${TARGET} talk-llama.cpp llama.cpp ../common.cpp ../common-sdl.cpp ../../ggml.c ../../ggml-alloc.c ../../whisper.cpp)
|
||||||
|
|
||||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
|
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
|
||||||
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||||
|
@ -2,6 +2,12 @@
|
|||||||
|
|
||||||
Talk with an LLaMA AI in your terminal
|
Talk with an LLaMA AI in your terminal
|
||||||
|
|
||||||
|
*Latest perf as of 2 Nov 2023 using Whisper Medium + LLaMA v2 13B Q8_0 on M2 Ultra:*
|
||||||
|
|
||||||
|
https://github.com/ggerganov/whisper.cpp/assets/1991296/d97a3788-bf2a-4756-9a43-60c6b391649e
|
||||||
|
|
||||||
|
*Previous demo running on CPUs*
|
||||||
|
|
||||||
[Demo Talk](https://user-images.githubusercontent.com/1991296/228024237-848f998c-c334-46a6-bef8-3271590da83b.mp4)
|
[Demo Talk](https://user-images.githubusercontent.com/1991296/228024237-848f998c-c334-46a6-bef8-3271590da83b.mp4)
|
||||||
|
|
||||||
## Building
|
## Building
|
||||||
@ -19,7 +25,7 @@ brew install sdl2
|
|||||||
make talk-llama
|
make talk-llama
|
||||||
|
|
||||||
# Run it
|
# Run it
|
||||||
./talk-llama -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/13B/ggml-model-q4_0.bin -p "Georgi" -t 8
|
./talk-llama -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/llama-13b/ggml-model-q4_0.gguf -p "Georgi" -t 8
|
||||||
```
|
```
|
||||||
|
|
||||||
- The `-mw` argument specifies the Whisper model that you would like to use. Recommended `base` or `small` for real-time experience
|
- The `-mw` argument specifies the Whisper model that you would like to use. Recommended `base` or `small` for real-time experience
|
||||||
@ -36,7 +42,7 @@ This feature is especially helpful for maintaining context in long conversations
|
|||||||
Example usage:
|
Example usage:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./talk-llama --session ./my-session-file -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/13B/ggml-model-q4_0.bin -p "Georgi" -t 8
|
./talk-llama --session ./my-session-file -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/llama-13b/ggml-model-q4_0.gguf -p "Georgi" -t 8
|
||||||
```
|
```
|
||||||
|
|
||||||
## TTS
|
## TTS
|
||||||
|
@ -1,474 +0,0 @@
|
|||||||
// Internal header to be included only by llama.cpp.
|
|
||||||
// Contains wrappers around OS interfaces.
|
|
||||||
|
|
||||||
#ifndef LLAMA_UTIL_H
|
|
||||||
#define LLAMA_UTIL_H
|
|
||||||
|
|
||||||
#include <cstdio>
|
|
||||||
#include <cstdint>
|
|
||||||
#include <cerrno>
|
|
||||||
#include <cstring>
|
|
||||||
#include <cstdarg>
|
|
||||||
#include <cstdlib>
|
|
||||||
#include <climits>
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
#include <stdexcept>
|
|
||||||
|
|
||||||
#ifdef __has_include
|
|
||||||
#if __has_include(<unistd.h>)
|
|
||||||
#include <unistd.h>
|
|
||||||
#if defined(_POSIX_MAPPED_FILES)
|
|
||||||
#include <sys/mman.h>
|
|
||||||
#endif
|
|
||||||
#if defined(_POSIX_MEMLOCK_RANGE)
|
|
||||||
#include <sys/resource.h>
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if defined(_WIN32)
|
|
||||||
#define WIN32_LEAN_AND_MEAN
|
|
||||||
#ifndef NOMINMAX
|
|
||||||
#define NOMINMAX
|
|
||||||
#endif
|
|
||||||
#include <windows.h>
|
|
||||||
#include <io.h>
|
|
||||||
#include <stdio.h> // for _fseeki64
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#define LLAMA_ASSERT(x) \
|
|
||||||
do { \
|
|
||||||
if (!(x)) { \
|
|
||||||
fprintf(stderr, "LLAMA_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
|
|
||||||
abort(); \
|
|
||||||
} \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
#ifdef __GNUC__
|
|
||||||
#ifdef __MINGW32__
|
|
||||||
__attribute__((format(gnu_printf, 1, 2)))
|
|
||||||
#else
|
|
||||||
__attribute__((format(printf, 1, 2)))
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
static std::string format(const char * fmt, ...) {
|
|
||||||
va_list ap, ap2;
|
|
||||||
va_start(ap, fmt);
|
|
||||||
va_copy(ap2, ap);
|
|
||||||
int size = vsnprintf(NULL, 0, fmt, ap);
|
|
||||||
LLAMA_ASSERT(size >= 0 && size < INT_MAX);
|
|
||||||
std::vector<char> buf(size + 1);
|
|
||||||
int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
|
|
||||||
LLAMA_ASSERT(size2 == size);
|
|
||||||
va_end(ap2);
|
|
||||||
va_end(ap);
|
|
||||||
return std::string(buf.data(), size);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct llama_file {
|
|
||||||
// use FILE * so we don't have to re-open the file to mmap
|
|
||||||
FILE * fp;
|
|
||||||
size_t size;
|
|
||||||
|
|
||||||
llama_file(const char * fname, const char * mode) {
|
|
||||||
fp = std::fopen(fname, mode);
|
|
||||||
if (fp == NULL) {
|
|
||||||
throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
|
|
||||||
}
|
|
||||||
seek(0, SEEK_END);
|
|
||||||
size = tell();
|
|
||||||
seek(0, SEEK_SET);
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t tell() const {
|
|
||||||
#ifdef _WIN32
|
|
||||||
__int64 ret = _ftelli64(fp);
|
|
||||||
#else
|
|
||||||
long ret = std::ftell(fp);
|
|
||||||
#endif
|
|
||||||
LLAMA_ASSERT(ret != -1); // this really shouldn't fail
|
|
||||||
return (size_t) ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
void seek(size_t offset, int whence) {
|
|
||||||
#ifdef _WIN32
|
|
||||||
int ret = _fseeki64(fp, (__int64) offset, whence);
|
|
||||||
#else
|
|
||||||
int ret = std::fseek(fp, (long) offset, whence);
|
|
||||||
#endif
|
|
||||||
LLAMA_ASSERT(ret == 0); // same
|
|
||||||
}
|
|
||||||
|
|
||||||
void read_raw(void * ptr, size_t len) const {
|
|
||||||
if (len == 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
errno = 0;
|
|
||||||
std::size_t ret = std::fread(ptr, len, 1, fp);
|
|
||||||
if (ferror(fp)) {
|
|
||||||
throw std::runtime_error(format("read error: %s", strerror(errno)));
|
|
||||||
}
|
|
||||||
if (ret != 1) {
|
|
||||||
throw std::runtime_error(std::string("unexpectedly reached end of file"));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::uint32_t read_u32() {
|
|
||||||
std::uint32_t ret;
|
|
||||||
read_raw(&ret, sizeof(ret));
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string read_string(std::uint32_t len) {
|
|
||||||
std::vector<char> chars(len);
|
|
||||||
read_raw(chars.data(), len);
|
|
||||||
return std::string(chars.data(), len);
|
|
||||||
}
|
|
||||||
|
|
||||||
void write_raw(const void * ptr, size_t len) const {
|
|
||||||
if (len == 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
errno = 0;
|
|
||||||
size_t ret = std::fwrite(ptr, len, 1, fp);
|
|
||||||
if (ret != 1) {
|
|
||||||
throw std::runtime_error(format("write error: %s", strerror(errno)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void write_u32(std::uint32_t val) {
|
|
||||||
write_raw(&val, sizeof(val));
|
|
||||||
}
|
|
||||||
|
|
||||||
~llama_file() {
|
|
||||||
if (fp) {
|
|
||||||
std::fclose(fp);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
#if defined(_WIN32)
|
|
||||||
static std::string llama_format_win_err(DWORD err) {
|
|
||||||
LPSTR buf;
|
|
||||||
size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
|
|
||||||
NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL);
|
|
||||||
if (!size) {
|
|
||||||
return "FormatMessageA failed";
|
|
||||||
}
|
|
||||||
std::string ret(buf, size);
|
|
||||||
LocalFree(buf);
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
struct llama_mmap {
|
|
||||||
void * addr;
|
|
||||||
size_t size;
|
|
||||||
|
|
||||||
llama_mmap(const llama_mmap &) = delete;
|
|
||||||
|
|
||||||
#ifdef _POSIX_MAPPED_FILES
|
|
||||||
static constexpr bool SUPPORTED = true;
|
|
||||||
|
|
||||||
llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */) {
|
|
||||||
size = file->size;
|
|
||||||
int fd = fileno(file->fp);
|
|
||||||
int flags = MAP_SHARED;
|
|
||||||
#ifdef __linux__
|
|
||||||
flags |= MAP_POPULATE;
|
|
||||||
#endif
|
|
||||||
addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0);
|
|
||||||
if (addr == MAP_FAILED) {
|
|
||||||
throw std::runtime_error(format("mmap failed: %s", strerror(errno)));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (prefetch > 0) {
|
|
||||||
// Advise the kernel to preload the mapped memory
|
|
||||||
if (posix_madvise(addr, std::min(file->size, prefetch), POSIX_MADV_WILLNEED)) {
|
|
||||||
fprintf(stderr, "warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n",
|
|
||||||
strerror(errno));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
~llama_mmap() {
|
|
||||||
munmap(addr, size);
|
|
||||||
}
|
|
||||||
#elif defined(_WIN32)
|
|
||||||
static constexpr bool SUPPORTED = true;
|
|
||||||
|
|
||||||
llama_mmap(struct llama_file * file, bool prefetch = true) {
|
|
||||||
size = file->size;
|
|
||||||
|
|
||||||
HANDLE hFile = (HANDLE) _get_osfhandle(_fileno(file->fp));
|
|
||||||
|
|
||||||
HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
|
|
||||||
DWORD error = GetLastError();
|
|
||||||
|
|
||||||
if (hMapping == NULL) {
|
|
||||||
throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str()));
|
|
||||||
}
|
|
||||||
|
|
||||||
addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
|
|
||||||
error = GetLastError();
|
|
||||||
CloseHandle(hMapping);
|
|
||||||
|
|
||||||
if (addr == NULL) {
|
|
||||||
throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str()));
|
|
||||||
}
|
|
||||||
|
|
||||||
#if _WIN32_WINNT >= _WIN32_WINNT_WIN8
|
|
||||||
if (prefetch) {
|
|
||||||
// Advise the kernel to preload the mapped memory
|
|
||||||
WIN32_MEMORY_RANGE_ENTRY range;
|
|
||||||
range.VirtualAddress = addr;
|
|
||||||
range.NumberOfBytes = (SIZE_T)size;
|
|
||||||
if (!PrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) {
|
|
||||||
fprintf(stderr, "warning: PrefetchVirtualMemory failed: %s\n",
|
|
||||||
llama_format_win_err(GetLastError()).c_str());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
#pragma message("warning: You are building for pre-Windows 8; prefetch not supported")
|
|
||||||
#endif // _WIN32_WINNT >= _WIN32_WINNT_WIN8
|
|
||||||
}
|
|
||||||
|
|
||||||
~llama_mmap() {
|
|
||||||
if (!UnmapViewOfFile(addr)) {
|
|
||||||
fprintf(stderr, "warning: UnmapViewOfFile failed: %s\n",
|
|
||||||
llama_format_win_err(GetLastError()).c_str());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
static constexpr bool SUPPORTED = false;
|
|
||||||
|
|
||||||
llama_mmap(struct llama_file *, bool prefetch = true) {
|
|
||||||
(void)prefetch;
|
|
||||||
throw std::runtime_error(std::string("mmap not supported"));
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
};
|
|
||||||
|
|
||||||
// Represents some region of memory being locked using mlock or VirtualLock;
|
|
||||||
// will automatically unlock on destruction.
|
|
||||||
struct llama_mlock {
|
|
||||||
void * addr = NULL;
|
|
||||||
size_t size = 0;
|
|
||||||
bool failed_already = false;
|
|
||||||
|
|
||||||
llama_mlock() {}
|
|
||||||
llama_mlock(const llama_mlock &) = delete;
|
|
||||||
|
|
||||||
~llama_mlock() {
|
|
||||||
if (size) {
|
|
||||||
raw_unlock(addr, size);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void init(void * ptr) {
|
|
||||||
LLAMA_ASSERT(addr == NULL && size == 0);
|
|
||||||
addr = ptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
void grow_to(size_t target_size) {
|
|
||||||
LLAMA_ASSERT(addr);
|
|
||||||
if (failed_already) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
size_t granularity = lock_granularity();
|
|
||||||
target_size = (target_size + granularity - 1) & ~(granularity - 1);
|
|
||||||
if (target_size > size) {
|
|
||||||
if (raw_lock((uint8_t *) addr + size, target_size - size)) {
|
|
||||||
size = target_size;
|
|
||||||
} else {
|
|
||||||
failed_already = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#ifdef _POSIX_MEMLOCK_RANGE
|
|
||||||
static constexpr bool SUPPORTED = true;
|
|
||||||
|
|
||||||
size_t lock_granularity() {
|
|
||||||
return (size_t) sysconf(_SC_PAGESIZE);
|
|
||||||
}
|
|
||||||
|
|
||||||
#ifdef __APPLE__
|
|
||||||
#define MLOCK_SUGGESTION \
|
|
||||||
"Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
|
|
||||||
"decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l).\n"
|
|
||||||
#else
|
|
||||||
#define MLOCK_SUGGESTION \
|
|
||||||
"Try increasing RLIMIT_MLOCK ('ulimit -l' as root).\n"
|
|
||||||
#endif
|
|
||||||
|
|
||||||
bool raw_lock(const void * addr, size_t size) {
|
|
||||||
if (!mlock(addr, size)) {
|
|
||||||
return true;
|
|
||||||
} else {
|
|
||||||
char* errmsg = std::strerror(errno);
|
|
||||||
bool suggest = (errno == ENOMEM);
|
|
||||||
|
|
||||||
// Check if the resource limit is fine after all
|
|
||||||
struct rlimit lock_limit;
|
|
||||||
if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit))
|
|
||||||
suggest = false;
|
|
||||||
if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size))
|
|
||||||
suggest = false;
|
|
||||||
|
|
||||||
fprintf(stderr, "warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s",
|
|
||||||
size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : "");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#undef MLOCK_SUGGESTION
|
|
||||||
|
|
||||||
void raw_unlock(void * addr, size_t size) {
|
|
||||||
if (munlock(addr, size)) {
|
|
||||||
fprintf(stderr, "warning: failed to munlock buffer: %s\n", std::strerror(errno));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#elif defined(_WIN32)
|
|
||||||
static constexpr bool SUPPORTED = true;
|
|
||||||
|
|
||||||
size_t lock_granularity() {
|
|
||||||
SYSTEM_INFO si;
|
|
||||||
GetSystemInfo(&si);
|
|
||||||
return (size_t) si.dwPageSize;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool raw_lock(void * ptr, size_t len) {
|
|
||||||
for (int tries = 1; ; tries++) {
|
|
||||||
if (VirtualLock(ptr, len)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (tries == 2) {
|
|
||||||
fprintf(stderr, "warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n",
|
|
||||||
len, size, llama_format_win_err(GetLastError()).c_str());
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// It failed but this was only the first try; increase the working
|
|
||||||
// set size and try again.
|
|
||||||
SIZE_T min_ws_size, max_ws_size;
|
|
||||||
if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) {
|
|
||||||
fprintf(stderr, "warning: GetProcessWorkingSetSize failed: %s\n",
|
|
||||||
llama_format_win_err(GetLastError()).c_str());
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
// Per MSDN: "The maximum number of pages that a process can lock
|
|
||||||
// is equal to the number of pages in its minimum working set minus
|
|
||||||
// a small overhead."
|
|
||||||
// Hopefully a megabyte is enough overhead:
|
|
||||||
size_t increment = len + 1048576;
|
|
||||||
// The minimum must be <= the maximum, so we need to increase both:
|
|
||||||
min_ws_size += increment;
|
|
||||||
max_ws_size += increment;
|
|
||||||
if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) {
|
|
||||||
fprintf(stderr, "warning: SetProcessWorkingSetSize failed: %s\n",
|
|
||||||
llama_format_win_err(GetLastError()).c_str());
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void raw_unlock(void * ptr, size_t len) {
|
|
||||||
if (!VirtualUnlock(ptr, len)) {
|
|
||||||
fprintf(stderr, "warning: failed to VirtualUnlock buffer: %s\n",
|
|
||||||
llama_format_win_err(GetLastError()).c_str());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
static constexpr bool SUPPORTED = false;
|
|
||||||
|
|
||||||
size_t lock_granularity() {
|
|
||||||
return (size_t) 65536;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool raw_lock(const void * addr, size_t len) {
|
|
||||||
fprintf(stderr, "warning: mlock not supported on this system\n");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
void raw_unlock(const void * addr, size_t len) {}
|
|
||||||
#endif
|
|
||||||
};
|
|
||||||
|
|
||||||
// Replacement for std::vector<uint8_t> that doesn't require zero-initialization.
|
|
||||||
struct llama_buffer {
|
|
||||||
uint8_t * addr = NULL;
|
|
||||||
size_t size = 0;
|
|
||||||
|
|
||||||
llama_buffer() = default;
|
|
||||||
|
|
||||||
void resize(size_t len) {
|
|
||||||
delete[] addr;
|
|
||||||
addr = new uint8_t[len];
|
|
||||||
size = len;
|
|
||||||
}
|
|
||||||
|
|
||||||
~llama_buffer() {
|
|
||||||
delete[] addr;
|
|
||||||
}
|
|
||||||
|
|
||||||
// disable copy and move
|
|
||||||
llama_buffer(const llama_buffer&) = delete;
|
|
||||||
llama_buffer(llama_buffer&&) = delete;
|
|
||||||
llama_buffer& operator=(const llama_buffer&) = delete;
|
|
||||||
llama_buffer& operator=(llama_buffer&&) = delete;
|
|
||||||
};
|
|
||||||
|
|
||||||
#ifdef GGML_USE_CUBLAS
|
|
||||||
#include "ggml-cuda.h"
|
|
||||||
struct llama_ctx_buffer {
|
|
||||||
uint8_t * addr = NULL;
|
|
||||||
bool is_cuda;
|
|
||||||
size_t size = 0;
|
|
||||||
|
|
||||||
llama_ctx_buffer() = default;
|
|
||||||
|
|
||||||
void resize(size_t size) {
|
|
||||||
free();
|
|
||||||
|
|
||||||
addr = (uint8_t *) ggml_cuda_host_malloc(size);
|
|
||||||
if (addr) {
|
|
||||||
is_cuda = true;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
// fall back to pageable memory
|
|
||||||
addr = new uint8_t[size];
|
|
||||||
is_cuda = false;
|
|
||||||
}
|
|
||||||
this->size = size;
|
|
||||||
}
|
|
||||||
|
|
||||||
void free() {
|
|
||||||
if (addr) {
|
|
||||||
if (is_cuda) {
|
|
||||||
ggml_cuda_host_free(addr);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
delete[] addr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
addr = NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
~llama_ctx_buffer() {
|
|
||||||
free();
|
|
||||||
}
|
|
||||||
|
|
||||||
// disable copy and move
|
|
||||||
llama_ctx_buffer(const llama_ctx_buffer&) = delete;
|
|
||||||
llama_ctx_buffer(llama_ctx_buffer&&) = delete;
|
|
||||||
llama_ctx_buffer& operator=(const llama_ctx_buffer&) = delete;
|
|
||||||
llama_ctx_buffer& operator=(llama_ctx_buffer&&) = delete;
|
|
||||||
};
|
|
||||||
#else
|
|
||||||
typedef llama_buffer llama_ctx_buffer;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#endif
|
|
File diff suppressed because it is too large
Load Diff
@ -1,8 +1,16 @@
|
|||||||
#ifndef LLAMA_H
|
#ifndef LLAMA_H
|
||||||
#define LLAMA_H
|
#define LLAMA_H
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
#ifdef GGML_USE_CUBLAS
|
||||||
|
#include "ggml-cuda.h"
|
||||||
|
#define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES
|
||||||
|
#else
|
||||||
|
#define LLAMA_MAX_DEVICES 1
|
||||||
|
#endif // GGML_USE_CUBLAS
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
#include <stdio.h>
|
||||||
#include <stdbool.h>
|
#include <stdbool.h>
|
||||||
|
|
||||||
#ifdef LLAMA_SHARED
|
#ifdef LLAMA_SHARED
|
||||||
@ -19,18 +27,26 @@
|
|||||||
# define LLAMA_API
|
# define LLAMA_API
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#define LLAMA_FILE_MAGIC_GGJT 0x67676a74u // 'ggjt'
|
#ifdef __GNUC__
|
||||||
#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
|
# define DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
|
||||||
#define LLAMA_FILE_MAGIC_GGMF 0x67676d66u // 'ggmf'
|
#elif defined(_MSC_VER)
|
||||||
#define LLAMA_FILE_MAGIC_GGML 0x67676d6cu // 'ggml'
|
# define DEPRECATED(func, hint) __declspec(deprecated(hint)) func
|
||||||
|
#else
|
||||||
|
# define DEPRECATED(func, hint) func
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define LLAMA_DEFAULT_SEED 0xFFFFFFFF
|
||||||
|
|
||||||
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
||||||
|
|
||||||
#define LLAMA_FILE_VERSION 3
|
|
||||||
#define LLAMA_FILE_MAGIC LLAMA_FILE_MAGIC_GGJT
|
|
||||||
#define LLAMA_FILE_MAGIC_UNVERSIONED LLAMA_FILE_MAGIC_GGML
|
|
||||||
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
||||||
#define LLAMA_SESSION_VERSION 1
|
#define LLAMA_SESSION_VERSION 1
|
||||||
|
|
||||||
|
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
|
||||||
|
// Defined when llama.cpp is compiled with support for offloading model layers to GPU.
|
||||||
|
#define LLAMA_SUPPORTS_GPU_OFFLOAD
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
@ -41,10 +57,57 @@ extern "C" {
|
|||||||
// TODO: show sample usage
|
// TODO: show sample usage
|
||||||
//
|
//
|
||||||
|
|
||||||
|
struct llama_model;
|
||||||
struct llama_context;
|
struct llama_context;
|
||||||
|
|
||||||
typedef int llama_token;
|
typedef int llama_token;
|
||||||
|
|
||||||
|
enum llama_log_level {
|
||||||
|
LLAMA_LOG_LEVEL_ERROR = 2,
|
||||||
|
LLAMA_LOG_LEVEL_WARN = 3,
|
||||||
|
LLAMA_LOG_LEVEL_INFO = 4
|
||||||
|
};
|
||||||
|
|
||||||
|
enum llama_vocab_type {
|
||||||
|
LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece
|
||||||
|
LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding
|
||||||
|
};
|
||||||
|
|
||||||
|
enum llama_token_type {
|
||||||
|
LLAMA_TOKEN_TYPE_UNDEFINED = 0,
|
||||||
|
LLAMA_TOKEN_TYPE_NORMAL = 1,
|
||||||
|
LLAMA_TOKEN_TYPE_UNKNOWN = 2,
|
||||||
|
LLAMA_TOKEN_TYPE_CONTROL = 3,
|
||||||
|
LLAMA_TOKEN_TYPE_USER_DEFINED = 4,
|
||||||
|
LLAMA_TOKEN_TYPE_UNUSED = 5,
|
||||||
|
LLAMA_TOKEN_TYPE_BYTE = 6,
|
||||||
|
};
|
||||||
|
|
||||||
|
// model file types
|
||||||
|
enum llama_ftype {
|
||||||
|
LLAMA_FTYPE_ALL_F32 = 0,
|
||||||
|
LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
|
||||||
|
// LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
|
||||||
|
// LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors
|
||||||
|
|
||||||
|
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
|
||||||
|
};
|
||||||
|
|
||||||
typedef struct llama_token_data {
|
typedef struct llama_token_data {
|
||||||
llama_token id; // token id
|
llama_token id; // token id
|
||||||
float logit; // log-odds of the token
|
float logit; // log-odds of the token
|
||||||
@ -60,67 +123,152 @@ extern "C" {
|
|||||||
typedef void (*llama_progress_callback)(float progress, void *ctx);
|
typedef void (*llama_progress_callback)(float progress, void *ctx);
|
||||||
|
|
||||||
struct llama_context_params {
|
struct llama_context_params {
|
||||||
int n_ctx; // text context
|
uint32_t seed; // RNG seed, -1 for random
|
||||||
int n_gpu_layers; // number of layers to store in VRAM
|
int32_t n_ctx; // text context
|
||||||
int seed; // RNG seed, -1 for random
|
int32_t n_batch; // prompt processing batch size
|
||||||
|
int32_t n_gpu_layers; // number of layers to store in VRAM
|
||||||
|
int32_t main_gpu; // the GPU that is used for scratch and small tensors
|
||||||
|
|
||||||
|
const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
|
||||||
|
|
||||||
|
// ref: https://github.com/ggerganov/llama.cpp/pull/2054
|
||||||
|
float rope_freq_base; // RoPE base frequency
|
||||||
|
float rope_freq_scale; // RoPE frequency scaling factor
|
||||||
|
|
||||||
|
// called with a progress value between 0 and 1, pass NULL to disable
|
||||||
|
llama_progress_callback progress_callback;
|
||||||
|
// context pointer passed to the progress callback
|
||||||
|
void * progress_callback_user_data;
|
||||||
|
|
||||||
|
// Keep the booleans together to avoid misalignment during copy-by-value.
|
||||||
|
bool low_vram; // if true, reduce VRAM usage at the cost of performance
|
||||||
|
bool mul_mat_q; // if true, use experimental mul_mat_q kernels
|
||||||
bool f16_kv; // use fp16 for KV cache
|
bool f16_kv; // use fp16 for KV cache
|
||||||
bool logits_all; // the llama_eval() call computes all logits, not just the last one
|
bool logits_all; // the llama_eval() call computes all logits, not just the last one
|
||||||
bool vocab_only; // only load the vocabulary, no weights
|
bool vocab_only; // only load the vocabulary, no weights
|
||||||
bool use_mmap; // use mmap if possible
|
bool use_mmap; // use mmap if possible
|
||||||
bool use_mlock; // force system to keep model in RAM
|
bool use_mlock; // force system to keep model in RAM
|
||||||
bool embedding; // embedding mode only
|
bool embedding; // embedding mode only
|
||||||
|
|
||||||
// called with a progress value between 0 and 1, pass NULL to disable
|
|
||||||
llama_progress_callback progress_callback;
|
|
||||||
// context pointer passed to the progress callback
|
|
||||||
void * progress_callback_user_data;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// model file types
|
// Signature for logging events
|
||||||
enum llama_ftype {
|
// Note that text includes the new line character at the end for most events.
|
||||||
LLAMA_FTYPE_ALL_F32 = 0,
|
// If your logging mechanism cannot handle that, check if the last character is '\n' and strip it
|
||||||
LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
|
// if it exists.
|
||||||
LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
|
// It might not exist for progress report where '.' is output repeatedly.
|
||||||
LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
|
typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data);
|
||||||
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
|
|
||||||
// LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
|
// model quantization parameters
|
||||||
// LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
|
typedef struct llama_model_quantize_params {
|
||||||
LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
|
int nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
|
||||||
LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
|
enum llama_ftype ftype; // quantize to this llama_ftype
|
||||||
LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
|
bool allow_requantize; // allow quantizing non-f32/f16 tensors
|
||||||
|
bool quantize_output_tensor; // quantize output.weight
|
||||||
|
bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
|
||||||
|
} llama_model_quantize_params;
|
||||||
|
|
||||||
|
// grammar types
|
||||||
|
struct llama_grammar;
|
||||||
|
|
||||||
|
// grammar element type
|
||||||
|
enum llama_gretype {
|
||||||
|
// end of rule definition
|
||||||
|
LLAMA_GRETYPE_END = 0,
|
||||||
|
|
||||||
|
// start of alternate definition for rule
|
||||||
|
LLAMA_GRETYPE_ALT = 1,
|
||||||
|
|
||||||
|
// non-terminal element: reference to rule
|
||||||
|
LLAMA_GRETYPE_RULE_REF = 2,
|
||||||
|
|
||||||
|
// terminal element: character (code point)
|
||||||
|
LLAMA_GRETYPE_CHAR = 3,
|
||||||
|
|
||||||
|
// inverse char(s) ([^a], [^a-b] [^abc])
|
||||||
|
LLAMA_GRETYPE_CHAR_NOT = 4,
|
||||||
|
|
||||||
|
// modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
|
||||||
|
// be an inclusive range ([a-z])
|
||||||
|
LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
|
||||||
|
|
||||||
|
// modifies a preceding LLAMA_GRETYPE_CHAR or
|
||||||
|
// LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
|
||||||
|
LLAMA_GRETYPE_CHAR_ALT = 6,
|
||||||
};
|
};
|
||||||
|
|
||||||
LLAMA_API struct llama_context_params llama_context_default_params();
|
typedef struct llama_grammar_element {
|
||||||
|
enum llama_gretype type;
|
||||||
|
uint32_t value; // Unicode code point or rule ID
|
||||||
|
} llama_grammar_element;
|
||||||
|
|
||||||
LLAMA_API bool llama_mmap_supported();
|
// performance timing information
|
||||||
LLAMA_API bool llama_mlock_supported();
|
struct llama_timings {
|
||||||
|
double t_start_ms;
|
||||||
|
double t_end_ms;
|
||||||
|
double t_load_ms;
|
||||||
|
double t_sample_ms;
|
||||||
|
double t_p_eval_ms;
|
||||||
|
double t_eval_ms;
|
||||||
|
|
||||||
|
int32_t n_sample;
|
||||||
|
int32_t n_p_eval;
|
||||||
|
int32_t n_eval;
|
||||||
|
};
|
||||||
|
|
||||||
|
LLAMA_API struct llama_context_params llama_context_default_params(void);
|
||||||
|
LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
|
||||||
|
|
||||||
// TODO: not great API - very likely to change
|
|
||||||
// Initialize the llama + ggml backend
|
// Initialize the llama + ggml backend
|
||||||
|
// If numa is true, use NUMA optimizations
|
||||||
// Call once at the start of the program
|
// Call once at the start of the program
|
||||||
LLAMA_API void llama_init_backend();
|
LLAMA_API void llama_backend_init(bool numa);
|
||||||
|
|
||||||
LLAMA_API int64_t llama_time_us();
|
// Call once at the end of the program - currently only used for MPI
|
||||||
|
LLAMA_API void llama_backend_free(void);
|
||||||
|
|
||||||
// Various functions for loading a ggml llama model.
|
LLAMA_API struct llama_model * llama_load_model_from_file(
|
||||||
// Allocate (almost) all memory needed for the model.
|
|
||||||
// Return NULL on failure
|
|
||||||
LLAMA_API struct llama_context * llama_init_from_file(
|
|
||||||
const char * path_model,
|
const char * path_model,
|
||||||
struct llama_context_params params);
|
struct llama_context_params params);
|
||||||
|
|
||||||
|
LLAMA_API void llama_free_model(struct llama_model * model);
|
||||||
|
|
||||||
|
LLAMA_API struct llama_context * llama_new_context_with_model(
|
||||||
|
struct llama_model * model,
|
||||||
|
struct llama_context_params params);
|
||||||
|
|
||||||
// Frees all allocated memory
|
// Frees all allocated memory
|
||||||
LLAMA_API void llama_free(struct llama_context * ctx);
|
LLAMA_API void llama_free(struct llama_context * ctx);
|
||||||
|
|
||||||
// TODO: not great API - very likely to change
|
LLAMA_API int64_t llama_time_us(void);
|
||||||
|
|
||||||
|
LLAMA_API int llama_max_devices (void);
|
||||||
|
LLAMA_API bool llama_mmap_supported (void);
|
||||||
|
LLAMA_API bool llama_mlock_supported(void);
|
||||||
|
|
||||||
|
LLAMA_API int llama_n_vocab (const struct llama_context * ctx);
|
||||||
|
LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
|
||||||
|
LLAMA_API int llama_n_ctx_train(const struct llama_context * ctx);
|
||||||
|
LLAMA_API int llama_n_embd (const struct llama_context * ctx);
|
||||||
|
|
||||||
|
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx);
|
||||||
|
|
||||||
|
LLAMA_API int llama_model_n_vocab (const struct llama_model * model);
|
||||||
|
LLAMA_API int llama_model_n_ctx (const struct llama_model * model);
|
||||||
|
LLAMA_API int llama_model_n_ctx_train(const struct llama_model * model);
|
||||||
|
LLAMA_API int llama_model_n_embd (const struct llama_model * model);
|
||||||
|
|
||||||
|
// Get a string describing the model type
|
||||||
|
LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
|
||||||
|
// Returns the total size of all the tensors in the model in bytes
|
||||||
|
LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
|
||||||
|
// Returns the total number of parameters in the model
|
||||||
|
LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
|
||||||
|
|
||||||
// Returns 0 on success
|
// Returns 0 on success
|
||||||
// nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given
|
|
||||||
LLAMA_API int llama_model_quantize(
|
LLAMA_API int llama_model_quantize(
|
||||||
const char * fname_inp,
|
const char * fname_inp,
|
||||||
const char * fname_out,
|
const char * fname_out,
|
||||||
enum llama_ftype ftype,
|
const llama_model_quantize_params * params);
|
||||||
int nthread);
|
|
||||||
|
|
||||||
// Apply a LoRA adapter to a loaded model
|
// Apply a LoRA adapter to a loaded model
|
||||||
// path_base_model is the path to a higher quality model to use as a base for
|
// path_base_model is the path to a higher quality model to use as a base for
|
||||||
@ -128,8 +276,15 @@ extern "C" {
|
|||||||
// The model needs to be reloaded before applying a new adapter, otherwise the adapter
|
// The model needs to be reloaded before applying a new adapter, otherwise the adapter
|
||||||
// will be applied on top of the previous one
|
// will be applied on top of the previous one
|
||||||
// Returns 0 on success
|
// Returns 0 on success
|
||||||
LLAMA_API int llama_apply_lora_from_file(
|
LLAMA_API DEPRECATED(int llama_apply_lora_from_file(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
|
const char * path_lora,
|
||||||
|
const char * path_base_model,
|
||||||
|
int n_threads),
|
||||||
|
"please use llama_model_apply_lora_from_file instead");
|
||||||
|
|
||||||
|
LLAMA_API int llama_model_apply_lora_from_file(
|
||||||
|
const struct llama_model * model,
|
||||||
const char * path_lora,
|
const char * path_lora,
|
||||||
const char * path_base_model,
|
const char * path_base_model,
|
||||||
int n_threads);
|
int n_threads);
|
||||||
@ -138,7 +293,7 @@ extern "C" {
|
|||||||
LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
|
LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
|
||||||
|
|
||||||
// Sets the current rng seed.
|
// Sets the current rng seed.
|
||||||
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, int seed);
|
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
|
||||||
|
|
||||||
// Returns the maximum size in bytes of the state (rng, logits, embedding
|
// Returns the maximum size in bytes of the state (rng, logits, embedding
|
||||||
// and kv_cache) - will often be smaller after compacting tokens
|
// and kv_cache) - will often be smaller after compacting tokens
|
||||||
@ -168,21 +323,19 @@ extern "C" {
|
|||||||
int n_past,
|
int n_past,
|
||||||
int n_threads);
|
int n_threads);
|
||||||
|
|
||||||
// Convert the provided text into tokens.
|
// Same as llama_eval, but use float matrix input directly.
|
||||||
// The tokens pointer must be large enough to hold the resulting tokens.
|
LLAMA_API int llama_eval_embd(
|
||||||
// Returns the number of tokens on success, no more than n_max_tokens
|
|
||||||
// Returns a negative number on failure - the number of tokens that would have been returned
|
|
||||||
// TODO: not sure if correct
|
|
||||||
LLAMA_API int llama_tokenize(
|
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
const char * text,
|
const float * embd,
|
||||||
llama_token * tokens,
|
int n_tokens,
|
||||||
int n_max_tokens,
|
int n_past,
|
||||||
bool add_bos);
|
int n_threads);
|
||||||
|
|
||||||
LLAMA_API int llama_n_vocab(const struct llama_context * ctx);
|
// Export a static computation graph for context of 511 and batch size of 1
|
||||||
LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
|
// NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these
|
||||||
LLAMA_API int llama_n_embd (const struct llama_context * ctx);
|
// parameters here to keep things simple
|
||||||
|
// IMPORTANT: do not use for anything else other than debugging and testing!
|
||||||
|
LLAMA_API int llama_eval_export(struct llama_context * ctx, const char * fname);
|
||||||
|
|
||||||
// Token logits obtained from the last call to llama_eval()
|
// Token logits obtained from the last call to llama_eval()
|
||||||
// The logits for the last token are stored in the last row
|
// The logits for the last token are stored in the last row
|
||||||
@ -195,15 +348,75 @@ extern "C" {
|
|||||||
// shape: [n_embd] (1-dimensional)
|
// shape: [n_embd] (1-dimensional)
|
||||||
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
||||||
|
|
||||||
// Token Id -> String. Uses the vocabulary in the provided context
|
//
|
||||||
LLAMA_API const char * llama_token_to_str(const struct llama_context * ctx, llama_token token);
|
// Vocab
|
||||||
|
//
|
||||||
|
|
||||||
|
LLAMA_API const char * llama_token_get_text(const struct llama_context * ctx, llama_token token);
|
||||||
|
|
||||||
|
LLAMA_API float llama_token_get_score(const struct llama_context * ctx, llama_token token);
|
||||||
|
|
||||||
|
LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_context * ctx, llama_token token);
|
||||||
|
|
||||||
// Special tokens
|
// Special tokens
|
||||||
LLAMA_API llama_token llama_token_bos();
|
LLAMA_API llama_token llama_token_bos(const struct llama_context * ctx); // beginning-of-sentence
|
||||||
LLAMA_API llama_token llama_token_eos();
|
LLAMA_API llama_token llama_token_eos(const struct llama_context * ctx); // end-of-sentence
|
||||||
LLAMA_API llama_token llama_token_nl();
|
LLAMA_API llama_token llama_token_nl (const struct llama_context * ctx); // next-line
|
||||||
|
|
||||||
|
//
|
||||||
|
// Tokenization
|
||||||
|
//
|
||||||
|
|
||||||
|
// Convert the provided text into tokens.
|
||||||
|
// The tokens pointer must be large enough to hold the resulting tokens.
|
||||||
|
// Returns the number of tokens on success, no more than n_max_tokens
|
||||||
|
// Returns a negative number on failure - the number of tokens that would have been returned
|
||||||
|
LLAMA_API int llama_tokenize(
|
||||||
|
struct llama_context * ctx,
|
||||||
|
const char * text,
|
||||||
|
llama_token * tokens,
|
||||||
|
int n_max_tokens,
|
||||||
|
bool add_bos);
|
||||||
|
|
||||||
|
LLAMA_API int llama_tokenize_with_model(
|
||||||
|
const struct llama_model * model,
|
||||||
|
const char * text,
|
||||||
|
llama_token * tokens,
|
||||||
|
int n_max_tokens,
|
||||||
|
bool add_bos);
|
||||||
|
|
||||||
|
// Token Id -> Piece.
|
||||||
|
// Uses the vocabulary in the provided context.
|
||||||
|
// Does not write null terminator to the buffer.
|
||||||
|
// User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
|
||||||
|
LLAMA_API int llama_token_to_piece(
|
||||||
|
const struct llama_context * ctx,
|
||||||
|
llama_token token,
|
||||||
|
char * buf,
|
||||||
|
int length);
|
||||||
|
|
||||||
|
LLAMA_API int llama_token_to_piece_with_model(
|
||||||
|
const struct llama_model * model,
|
||||||
|
llama_token token,
|
||||||
|
char * buf,
|
||||||
|
int length);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Grammar
|
||||||
|
//
|
||||||
|
|
||||||
|
LLAMA_API struct llama_grammar * llama_grammar_init(
|
||||||
|
const llama_grammar_element ** rules,
|
||||||
|
size_t n_rules,
|
||||||
|
size_t start_rule_index);
|
||||||
|
|
||||||
|
LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
|
||||||
|
|
||||||
|
LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
|
||||||
|
|
||||||
|
//
|
||||||
// Sampling functions
|
// Sampling functions
|
||||||
|
//
|
||||||
|
|
||||||
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
|
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
|
||||||
LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty);
|
LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty);
|
||||||
@ -211,6 +424,16 @@ extern "C" {
|
|||||||
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
|
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
|
||||||
LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
|
LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
|
||||||
|
|
||||||
|
/// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
|
||||||
|
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
|
||||||
|
/// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
|
||||||
|
/// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
|
||||||
|
LLAMA_API void llama_sample_classifier_free_guidance(
|
||||||
|
struct llama_context * ctx,
|
||||||
|
llama_token_data_array * candidates,
|
||||||
|
struct llama_context * guidance_ctx,
|
||||||
|
float scale);
|
||||||
|
|
||||||
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
||||||
LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
|
LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
|
||||||
|
|
||||||
@ -227,6 +450,9 @@ extern "C" {
|
|||||||
LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
|
LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
|
||||||
LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
|
LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
|
||||||
|
|
||||||
|
/// @details Apply constraints from grammar
|
||||||
|
LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar);
|
||||||
|
|
||||||
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||||
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
||||||
@ -248,13 +474,60 @@ extern "C" {
|
|||||||
/// @details Randomly selects a token from the candidates based on their probabilities.
|
/// @details Randomly selects a token from the candidates based on their probabilities.
|
||||||
LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
|
LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
|
||||||
|
|
||||||
|
/// @details Accepts the sampled token into the grammar
|
||||||
|
LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Beam search
|
||||||
|
//
|
||||||
|
|
||||||
|
struct llama_beam_view {
|
||||||
|
const llama_token * tokens;
|
||||||
|
size_t n_tokens;
|
||||||
|
float p; // Cumulative beam probability (renormalized relative to all beams)
|
||||||
|
bool eob; // Callback should set this to true when a beam is at end-of-beam.
|
||||||
|
};
|
||||||
|
|
||||||
|
// Passed to beam_search_callback function.
|
||||||
|
// Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
|
||||||
|
// (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
|
||||||
|
// These pointers are valid only during the synchronous callback, so should not be saved.
|
||||||
|
struct llama_beams_state {
|
||||||
|
struct llama_beam_view * beam_views;
|
||||||
|
size_t n_beams; // Number of elements in beam_views[].
|
||||||
|
size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
|
||||||
|
bool last_call; // True iff this is the last callback invocation.
|
||||||
|
};
|
||||||
|
|
||||||
|
// Type of pointer to the beam_search_callback function.
|
||||||
|
// void* callback_data is any custom data passed to llama_beam_search, that is subsequently
|
||||||
|
// passed back to beam_search_callback. This avoids having to use global variables in the callback.
|
||||||
|
typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state);
|
||||||
|
|
||||||
|
/// @details Deterministically returns entire sentence constructed by a beam search.
|
||||||
|
/// @param ctx Pointer to the llama_context.
|
||||||
|
/// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
|
||||||
|
/// @param callback_data A pointer that is simply passed back to callback.
|
||||||
|
/// @param n_beams Number of beams to use.
|
||||||
|
/// @param n_past Number of tokens already evaluated.
|
||||||
|
/// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
|
||||||
|
/// @param n_threads Number of threads as passed to llama_eval().
|
||||||
|
LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);
|
||||||
|
|
||||||
// Performance information
|
// Performance information
|
||||||
|
LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
|
||||||
LLAMA_API void llama_print_timings(struct llama_context * ctx);
|
LLAMA_API void llama_print_timings(struct llama_context * ctx);
|
||||||
LLAMA_API void llama_reset_timings(struct llama_context * ctx);
|
LLAMA_API void llama_reset_timings(struct llama_context * ctx);
|
||||||
|
|
||||||
// Print system information
|
// Print system information
|
||||||
LLAMA_API const char * llama_print_system_info(void);
|
LLAMA_API const char * llama_print_system_info(void);
|
||||||
|
|
||||||
|
// Set callback for all future logging events.
|
||||||
|
// If this is not called, or NULL is supplied, everything is output on stderr.
|
||||||
|
LLAMA_API void llama_log_set(llama_log_callback log_callback, void * user_data);
|
||||||
|
|
||||||
|
LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
@ -264,10 +537,11 @@ extern "C" {
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
struct ggml_tensor;
|
struct ggml_tensor;
|
||||||
|
|
||||||
std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx);
|
const std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx);
|
||||||
|
|
||||||
#endif
|
#endif // LLAMA_API_INTERNAL
|
||||||
|
|
||||||
#endif // LLAMA_H
|
#endif // LLAMA_H
|
||||||
|
0
examples/talk-llama/speak
Normal file → Executable file
0
examples/talk-llama/speak
Normal file → Executable file
@ -1,8 +1,8 @@
|
|||||||
// Talk with AI
|
// Talk with AI
|
||||||
//
|
//
|
||||||
|
|
||||||
#include "common.h"
|
|
||||||
#include "common-sdl.h"
|
#include "common-sdl.h"
|
||||||
|
#include "common.h"
|
||||||
#include "whisper.h"
|
#include "whisper.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
|
||||||
@ -25,6 +25,20 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
|
||||||
|
std::vector<char> result(8, 0);
|
||||||
|
const int n_tokens = llama_token_to_piece(ctx, token, result.data(), result.size());
|
||||||
|
if (n_tokens < 0) {
|
||||||
|
result.resize(-n_tokens);
|
||||||
|
int check = llama_token_to_piece(ctx, token, result.data(), result.size());
|
||||||
|
GGML_ASSERT(check == -n_tokens);
|
||||||
|
} else {
|
||||||
|
result.resize(n_tokens);
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::string(result.data(), result.size());
|
||||||
|
}
|
||||||
|
|
||||||
// command-line parameters
|
// command-line parameters
|
||||||
struct whisper_params {
|
struct whisper_params {
|
||||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||||
@ -235,7 +249,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
// llama init
|
// llama init
|
||||||
|
|
||||||
llama_init_backend();
|
llama_backend_init(true);
|
||||||
|
|
||||||
auto lparams = llama_context_default_params();
|
auto lparams = llama_context_default_params();
|
||||||
|
|
||||||
@ -244,7 +258,9 @@ int main(int argc, char ** argv) {
|
|||||||
lparams.seed = 1;
|
lparams.seed = 1;
|
||||||
lparams.f16_kv = true;
|
lparams.f16_kv = true;
|
||||||
|
|
||||||
struct llama_context * ctx_llama = llama_init_from_file(params.model_llama.c_str(), lparams);
|
struct llama_model * model_llama = llama_load_model_from_file(params.model_llama.c_str(), lparams);
|
||||||
|
|
||||||
|
struct llama_context * ctx_llama = llama_new_context_with_model(model_llama, lparams);
|
||||||
|
|
||||||
// print some info about the processing
|
// print some info about the processing
|
||||||
{
|
{
|
||||||
@ -267,7 +283,6 @@ int main(int argc, char ** argv) {
|
|||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// init audio
|
// init audio
|
||||||
|
|
||||||
audio_async audio(30*1000);
|
audio_async audio(30*1000);
|
||||||
@ -278,8 +293,6 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
audio.resume();
|
audio.resume();
|
||||||
|
|
||||||
int n_iter = 0;
|
|
||||||
|
|
||||||
bool is_running = true;
|
bool is_running = true;
|
||||||
bool force_speak = false;
|
bool force_speak = false;
|
||||||
|
|
||||||
@ -514,7 +527,7 @@ int main(int argc, char ** argv) {
|
|||||||
//printf("\n---\n");
|
//printf("\n---\n");
|
||||||
//printf("resetting: '");
|
//printf("resetting: '");
|
||||||
//for (int i = 0; i < (int) embd.size(); i++) {
|
//for (int i = 0; i < (int) embd.size(); i++) {
|
||||||
// printf("%s", llama_token_to_str(ctx_llama, embd[i]));
|
// printf("%s", llama_token_to_piece(ctx_llama, embd[i]));
|
||||||
//}
|
//}
|
||||||
//printf("'\n");
|
//printf("'\n");
|
||||||
//printf("\n---\n");
|
//printf("\n---\n");
|
||||||
@ -582,7 +595,7 @@ int main(int argc, char ** argv) {
|
|||||||
auto logits = llama_get_logits(ctx_llama);
|
auto logits = llama_get_logits(ctx_llama);
|
||||||
auto n_vocab = llama_n_vocab(ctx_llama);
|
auto n_vocab = llama_n_vocab(ctx_llama);
|
||||||
|
|
||||||
logits[llama_token_eos()] = 0;
|
logits[llama_token_eos(ctx_llama)] = 0;
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
candidates.reserve(n_vocab);
|
candidates.reserve(n_vocab);
|
||||||
@ -593,13 +606,13 @@ int main(int argc, char ** argv) {
|
|||||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||||
|
|
||||||
// apply repeat penalty
|
// apply repeat penalty
|
||||||
const float nl_logit = logits[llama_token_nl()];
|
const float nl_logit = logits[llama_token_nl(ctx_llama)];
|
||||||
|
|
||||||
llama_sample_repetition_penalty(ctx_llama, &candidates_p,
|
llama_sample_repetition_penalty(ctx_llama, &candidates_p,
|
||||||
embd_inp.data() + std::max(0, n_past - repeat_last_n),
|
embd_inp.data() + std::max(0, n_past - repeat_last_n),
|
||||||
repeat_last_n, repeat_penalty);
|
repeat_last_n, repeat_penalty);
|
||||||
|
|
||||||
logits[llama_token_nl()] = nl_logit;
|
logits[llama_token_nl(ctx_llama)] = nl_logit;
|
||||||
|
|
||||||
if (temp <= 0) {
|
if (temp <= 0) {
|
||||||
// Greedy sampling
|
// Greedy sampling
|
||||||
@ -613,22 +626,22 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (id != llama_token_eos()) {
|
if (id != llama_token_eos(ctx_llama)) {
|
||||||
// add it to the context
|
// add it to the context
|
||||||
embd.push_back(id);
|
embd.push_back(id);
|
||||||
|
|
||||||
text_to_speak += llama_token_to_str(ctx_llama, id);
|
text_to_speak += llama_token_to_piece(ctx_llama, id);
|
||||||
|
|
||||||
printf("%s", llama_token_to_str(ctx_llama, id));
|
printf("%s", llama_token_to_piece(ctx_llama, id).c_str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
std::string last_output;
|
std::string last_output;
|
||||||
for (int i = embd_inp.size() - 16; i < (int) embd_inp.size(); i++) {
|
for (int i = embd_inp.size() - 16; i < (int) embd_inp.size(); i++) {
|
||||||
last_output += llama_token_to_str(ctx_llama, embd_inp[i]);
|
last_output += llama_token_to_piece(ctx_llama, embd_inp[i]);
|
||||||
}
|
}
|
||||||
last_output += llama_token_to_str(ctx_llama, embd[0]);
|
last_output += llama_token_to_piece(ctx_llama, embd[0]);
|
||||||
|
|
||||||
for (std::string & antiprompt : antiprompts) {
|
for (std::string & antiprompt : antiprompts) {
|
||||||
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
|
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
|
||||||
@ -649,11 +662,12 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
text_to_speak = ::replace(text_to_speak, "\"", "");
|
text_to_speak = ::replace(text_to_speak, "\"", "");
|
||||||
system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str());
|
int ret = system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str());
|
||||||
|
if (ret != 0) {
|
||||||
|
fprintf(stderr, "%s: failed to speak\n", __func__);
|
||||||
|
}
|
||||||
|
|
||||||
audio.clear();
|
audio.clear();
|
||||||
|
|
||||||
++n_iter;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -31,7 +31,7 @@ To run this, you will need a ggml GPT-2 model: [instructions](https://github.com
|
|||||||
Alternatively, you can simply download the smallest ggml GPT-2 117M model (240 MB) like this:
|
Alternatively, you can simply download the smallest ggml GPT-2 117M model (240 MB) like this:
|
||||||
|
|
||||||
```
|
```
|
||||||
wget --quiet --show-progress -O models/ggml-gpt-2-117M.bin https://huggingface.co/ggerganov/ggml/raw/main/ggml-model-gpt-2-117M.bin
|
wget --quiet --show-progress -O models/ggml-gpt-2-117M.bin https://huggingface.co/ggerganov/ggml/resolve/main/ggml-model-gpt-2-117M.bin
|
||||||
```
|
```
|
||||||
|
|
||||||
## TTS
|
## TTS
|
||||||
|
@ -191,9 +191,9 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
|||||||
// create the ggml context
|
// create the ggml context
|
||||||
{
|
{
|
||||||
struct ggml_init_params params = {
|
struct ggml_init_params params = {
|
||||||
.mem_size = ctx_size,
|
/*.mem_size =*/ ctx_size,
|
||||||
.mem_buffer = NULL,
|
/*.mem_buffer =*/ NULL,
|
||||||
.no_alloc = false,
|
/*.no_alloc =*/ false,
|
||||||
};
|
};
|
||||||
|
|
||||||
model.ctx = ggml_init(params);
|
model.ctx = ggml_init(params);
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
// Talk with AI
|
// Talk with AI
|
||||||
//
|
//
|
||||||
|
|
||||||
#include "common.h"
|
|
||||||
#include "common-sdl.h"
|
#include "common-sdl.h"
|
||||||
|
#include "common.h"
|
||||||
#include "whisper.h"
|
#include "whisper.h"
|
||||||
#include "gpt-2.h"
|
#include "gpt-2.h"
|
||||||
|
|
||||||
@ -349,7 +349,10 @@ int main(int argc, char ** argv) {
|
|||||||
gpt2_set_prompt(ctx_gpt, prompt_base.c_str());
|
gpt2_set_prompt(ctx_gpt, prompt_base.c_str());
|
||||||
|
|
||||||
text_to_speak = ::replace(text_to_speak, params.person + ": ", "");
|
text_to_speak = ::replace(text_to_speak, params.person + ": ", "");
|
||||||
system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str());
|
int ret = system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str());
|
||||||
|
if (ret != 0) {
|
||||||
|
fprintf(stderr, "%s: system() failed!\n", __func__);
|
||||||
|
}
|
||||||
|
|
||||||
audio.clear();
|
audio.clear();
|
||||||
|
|
||||||
|
2
examples/whisper.android/.idea/compiler.xml
generated
2
examples/whisper.android/.idea/compiler.xml
generated
@ -1,6 +1,6 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<project version="4">
|
<project version="4">
|
||||||
<component name="CompilerConfiguration">
|
<component name="CompilerConfiguration">
|
||||||
<bytecodeTargetLevel target="11" />
|
<bytecodeTargetLevel target="17" />
|
||||||
</component>
|
</component>
|
||||||
</project>
|
</project>
|
4
examples/whisper.android/.idea/gradle.xml
generated
4
examples/whisper.android/.idea/gradle.xml
generated
@ -4,15 +4,15 @@
|
|||||||
<component name="GradleSettings">
|
<component name="GradleSettings">
|
||||||
<option name="linkedExternalProjectsSettings">
|
<option name="linkedExternalProjectsSettings">
|
||||||
<GradleProjectSettings>
|
<GradleProjectSettings>
|
||||||
<option name="testRunner" value="GRADLE" />
|
|
||||||
<option name="distributionType" value="DEFAULT_WRAPPED" />
|
|
||||||
<option name="externalProjectPath" value="$PROJECT_DIR$" />
|
<option name="externalProjectPath" value="$PROJECT_DIR$" />
|
||||||
|
<option name="gradleJvm" value="#GRADLE_LOCAL_JAVA_HOME" />
|
||||||
<option name="modules">
|
<option name="modules">
|
||||||
<set>
|
<set>
|
||||||
<option value="$PROJECT_DIR$" />
|
<option value="$PROJECT_DIR$" />
|
||||||
<option value="$PROJECT_DIR$/app" />
|
<option value="$PROJECT_DIR$/app" />
|
||||||
</set>
|
</set>
|
||||||
</option>
|
</option>
|
||||||
|
<option name="resolveExternalAnnotations" value="false" />
|
||||||
</GradleProjectSettings>
|
</GradleProjectSettings>
|
||||||
</option>
|
</option>
|
||||||
</component>
|
</component>
|
||||||
|
2
examples/whisper.android/.idea/misc.xml
generated
2
examples/whisper.android/.idea/misc.xml
generated
@ -1,7 +1,7 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<project version="4">
|
<project version="4">
|
||||||
<component name="ExternalStorageConfigurationManager" enabled="true" />
|
<component name="ExternalStorageConfigurationManager" enabled="true" />
|
||||||
<component name="ProjectRootManager" version="2" languageLevel="JDK_11" default="true" project-jdk-name="Android Studio default JDK" project-jdk-type="JavaSDK">
|
<component name="ProjectRootManager" version="2" languageLevel="JDK_17" default="true" project-jdk-name="jbr-17" project-jdk-type="JavaSDK">
|
||||||
<output url="file://$PROJECT_DIR$/build/classes" />
|
<output url="file://$PROJECT_DIR$/build/classes" />
|
||||||
</component>
|
</component>
|
||||||
<component name="ProjectType">
|
<component name="ProjectType">
|
||||||
|
@ -5,12 +5,12 @@ plugins {
|
|||||||
|
|
||||||
android {
|
android {
|
||||||
namespace 'com.whispercppdemo'
|
namespace 'com.whispercppdemo'
|
||||||
compileSdk 33
|
compileSdk 34
|
||||||
|
|
||||||
defaultConfig {
|
defaultConfig {
|
||||||
applicationId "com.whispercppdemo"
|
applicationId "com.whispercppdemo"
|
||||||
minSdk 26
|
minSdk 26
|
||||||
targetSdk 32
|
targetSdk 34
|
||||||
versionCode 1
|
versionCode 1
|
||||||
versionName "1.0"
|
versionName "1.0"
|
||||||
|
|
||||||
@ -31,19 +31,19 @@ android {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
compileOptions {
|
compileOptions {
|
||||||
sourceCompatibility JavaVersion.VERSION_1_8
|
sourceCompatibility JavaVersion.VERSION_17
|
||||||
targetCompatibility JavaVersion.VERSION_1_8
|
targetCompatibility JavaVersion.VERSION_17
|
||||||
}
|
}
|
||||||
kotlinOptions {
|
kotlinOptions {
|
||||||
jvmTarget = '1.8'
|
jvmTarget = '17'
|
||||||
}
|
}
|
||||||
buildFeatures {
|
buildFeatures {
|
||||||
compose true
|
compose true
|
||||||
}
|
}
|
||||||
composeOptions {
|
composeOptions {
|
||||||
kotlinCompilerExtensionVersion '1.3.1'
|
kotlinCompilerExtensionVersion '1.5.0'
|
||||||
}
|
}
|
||||||
ndkVersion "25.1.8937393"
|
ndkVersion "25.2.9519653"
|
||||||
externalNativeBuild {
|
externalNativeBuild {
|
||||||
cmake {
|
cmake {
|
||||||
path = file("src/main/jni/whisper/CMakeLists.txt")
|
path = file("src/main/jni/whisper/CMakeLists.txt")
|
||||||
@ -57,19 +57,19 @@ android {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
implementation 'androidx.activity:activity-compose:1.6.1'
|
implementation 'androidx.activity:activity-compose:1.7.2'
|
||||||
implementation 'androidx.compose.material:material-icons-core:1.3.1'
|
implementation 'androidx.compose.material:material-icons-core:1.5.0'
|
||||||
implementation 'androidx.compose.material3:material3:1.0.1'
|
implementation 'androidx.compose.material3:material3:1.1.1'
|
||||||
implementation "androidx.compose.ui:ui:1.3.2"
|
implementation "androidx.compose.ui:ui:1.5.0"
|
||||||
implementation "androidx.compose.ui:ui-tooling-preview:1.3.2"
|
implementation "androidx.compose.ui:ui-tooling-preview:1.5.0"
|
||||||
implementation 'androidx.lifecycle:lifecycle-viewmodel-compose:2.5.1'
|
implementation 'androidx.lifecycle:lifecycle-viewmodel-compose:2.6.1'
|
||||||
implementation "com.google.accompanist:accompanist-permissions:0.28.0"
|
implementation "com.google.accompanist:accompanist-permissions:0.28.0"
|
||||||
implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-core:1.6.4'
|
implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-core:1.7.2'
|
||||||
|
|
||||||
testImplementation 'junit:junit:4.13.2'
|
testImplementation 'junit:junit:4.13.2'
|
||||||
androidTestImplementation 'androidx.test.ext:junit:1.1.4'
|
androidTestImplementation 'androidx.test.ext:junit:1.1.5'
|
||||||
androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.0'
|
androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.1'
|
||||||
androidTestImplementation "androidx.compose.ui:ui-test-junit4:1.3.2"
|
androidTestImplementation "androidx.compose.ui:ui-test-junit4:1.5.0"
|
||||||
debugImplementation "androidx.compose.ui:ui-tooling:1.3.2"
|
debugImplementation "androidx.compose.ui:ui-tooling:1.5.0"
|
||||||
debugImplementation "androidx.compose.ui:ui-test-manifest:1.3.2"
|
debugImplementation "androidx.compose.ui:ui-test-manifest:1.5.0"
|
||||||
}
|
}
|
@ -66,7 +66,7 @@ private fun MainScreen(
|
|||||||
|
|
||||||
@Composable
|
@Composable
|
||||||
private fun MessageLog(log: String) {
|
private fun MessageLog(log: String) {
|
||||||
SelectionContainer() {
|
SelectionContainer {
|
||||||
Text(modifier = Modifier.verticalScroll(rememberScrollState()), text = log)
|
Text(modifier = Modifier.verticalScroll(rememberScrollState()), text = log)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -47,7 +47,7 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private suspend fun printSystemInfo() {
|
private suspend fun printSystemInfo() {
|
||||||
printMessage(String.format("System Info: %s\n", WhisperContext.getSystemInfo()));
|
printMessage(String.format("System Info: %s\n", WhisperContext.getSystemInfo()))
|
||||||
}
|
}
|
||||||
|
|
||||||
private suspend fun loadData() {
|
private suspend fun loadData() {
|
||||||
|
@ -13,7 +13,7 @@ import androidx.compose.runtime.SideEffect
|
|||||||
import androidx.compose.ui.graphics.toArgb
|
import androidx.compose.ui.graphics.toArgb
|
||||||
import androidx.compose.ui.platform.LocalContext
|
import androidx.compose.ui.platform.LocalContext
|
||||||
import androidx.compose.ui.platform.LocalView
|
import androidx.compose.ui.platform.LocalView
|
||||||
import androidx.core.view.ViewCompat
|
import androidx.core.view.WindowCompat
|
||||||
|
|
||||||
private val DarkColorScheme = darkColorScheme(
|
private val DarkColorScheme = darkColorScheme(
|
||||||
primary = Purple80,
|
primary = Purple80,
|
||||||
@ -55,8 +55,9 @@ fun WhisperCppDemoTheme(
|
|||||||
val view = LocalView.current
|
val view = LocalView.current
|
||||||
if (!view.isInEditMode) {
|
if (!view.isInEditMode) {
|
||||||
SideEffect {
|
SideEffect {
|
||||||
(view.context as Activity).window.statusBarColor = colorScheme.primary.toArgb()
|
val window = (view.context as Activity).window
|
||||||
ViewCompat.getWindowInsetsController(view)?.isAppearanceLightStatusBars = darkTheme
|
window.statusBarColor = colorScheme.primary.toArgb()
|
||||||
|
WindowCompat.getInsetsController(window, view).isAppearanceLightStatusBars = darkTheme
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -18,7 +18,9 @@ class WhisperContext private constructor(private var ptr: Long) {
|
|||||||
|
|
||||||
suspend fun transcribeData(data: FloatArray): String = withContext(scope.coroutineContext) {
|
suspend fun transcribeData(data: FloatArray): String = withContext(scope.coroutineContext) {
|
||||||
require(ptr != 0L)
|
require(ptr != 0L)
|
||||||
WhisperLib.fullTranscribe(ptr, data)
|
val numThreads = WhisperCpuConfig.preferredThreadCount
|
||||||
|
Log.d(LOG_TAG, "Selecting $numThreads threads")
|
||||||
|
WhisperLib.fullTranscribe(ptr, numThreads, data)
|
||||||
val textCount = WhisperLib.getTextSegmentCount(ptr)
|
val textCount = WhisperLib.getTextSegmentCount(ptr)
|
||||||
return@withContext buildString {
|
return@withContext buildString {
|
||||||
for (i in 0 until textCount) {
|
for (i in 0 until textCount) {
|
||||||
@ -126,7 +128,7 @@ private class WhisperLib {
|
|||||||
external fun initContextFromAsset(assetManager: AssetManager, assetPath: String): Long
|
external fun initContextFromAsset(assetManager: AssetManager, assetPath: String): Long
|
||||||
external fun initContext(modelPath: String): Long
|
external fun initContext(modelPath: String): Long
|
||||||
external fun freeContext(contextPtr: Long)
|
external fun freeContext(contextPtr: Long)
|
||||||
external fun fullTranscribe(contextPtr: Long, audioData: FloatArray)
|
external fun fullTranscribe(contextPtr: Long, numThreads: Int, audioData: FloatArray)
|
||||||
external fun getTextSegmentCount(contextPtr: Long): Int
|
external fun getTextSegmentCount(contextPtr: Long): Int
|
||||||
external fun getTextSegment(contextPtr: Long, index: Int): String
|
external fun getTextSegment(contextPtr: Long, index: Int): String
|
||||||
external fun getSystemInfo(): String
|
external fun getSystemInfo(): String
|
||||||
|
@ -0,0 +1,73 @@
|
|||||||
|
package com.whispercppdemo.whisper
|
||||||
|
|
||||||
|
import android.util.Log
|
||||||
|
import java.io.BufferedReader
|
||||||
|
import java.io.FileReader
|
||||||
|
|
||||||
|
object WhisperCpuConfig {
|
||||||
|
val preferredThreadCount: Int
|
||||||
|
// Always use at least 2 threads:
|
||||||
|
get() = CpuInfo.getHighPerfCpuCount().coerceAtLeast(2)
|
||||||
|
}
|
||||||
|
|
||||||
|
private class CpuInfo(private val lines: List<String>) {
|
||||||
|
private fun getHighPerfCpuCount(): Int = try {
|
||||||
|
getHighPerfCpuCountByFrequencies()
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.d(LOG_TAG, "Couldn't read CPU frequencies", e)
|
||||||
|
getHighPerfCpuCountByVariant()
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun getHighPerfCpuCountByFrequencies(): Int =
|
||||||
|
getCpuValues(property = "processor") { getMaxCpuFrequency(it.toInt()) }
|
||||||
|
.also { Log.d(LOG_TAG, "Binned cpu frequencies (frequency, count): ${it.binnedValues()}") }
|
||||||
|
.countDroppingMin()
|
||||||
|
|
||||||
|
private fun getHighPerfCpuCountByVariant(): Int =
|
||||||
|
getCpuValues(property = "CPU variant") { it.substringAfter("0x").toInt(radix = 16) }
|
||||||
|
.also { Log.d(LOG_TAG, "Binned cpu variants (variant, count): ${it.binnedValues()}") }
|
||||||
|
.countKeepingMin()
|
||||||
|
|
||||||
|
private fun List<Int>.binnedValues() = groupingBy { it }.eachCount()
|
||||||
|
|
||||||
|
private fun getCpuValues(property: String, mapper: (String) -> Int) = lines
|
||||||
|
.asSequence()
|
||||||
|
.filter { it.startsWith(property) }
|
||||||
|
.map { mapper(it.substringAfter(':').trim()) }
|
||||||
|
.sorted()
|
||||||
|
.toList()
|
||||||
|
|
||||||
|
|
||||||
|
private fun List<Int>.countDroppingMin(): Int {
|
||||||
|
val min = min()
|
||||||
|
return count { it > min }
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun List<Int>.countKeepingMin(): Int {
|
||||||
|
val min = min()
|
||||||
|
return count { it == min }
|
||||||
|
}
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
private const val LOG_TAG = "WhisperCpuConfig"
|
||||||
|
|
||||||
|
fun getHighPerfCpuCount(): Int = try {
|
||||||
|
readCpuInfo().getHighPerfCpuCount()
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.d(LOG_TAG, "Couldn't read CPU info", e)
|
||||||
|
// Our best guess -- just return the # of CPUs minus 4.
|
||||||
|
(Runtime.getRuntime().availableProcessors() - 4).coerceAtLeast(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun readCpuInfo() = CpuInfo(
|
||||||
|
BufferedReader(FileReader("/proc/cpuinfo"))
|
||||||
|
.useLines { it.toList() }
|
||||||
|
)
|
||||||
|
|
||||||
|
private fun getMaxCpuFrequency(cpuIndex: Int): Int {
|
||||||
|
val path = "/sys/devices/system/cpu/cpu${cpuIndex}/cpufreq/cpuinfo_max_freq"
|
||||||
|
val maxFreq = BufferedReader(FileReader(path)).use { it.readLine() }
|
||||||
|
return maxFreq.toInt()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -8,6 +8,7 @@ set(WHISPER_LIB_DIR ${CMAKE_SOURCE_DIR}/../../../../../../../)
|
|||||||
set(
|
set(
|
||||||
SOURCE_FILES
|
SOURCE_FILES
|
||||||
${WHISPER_LIB_DIR}/ggml.c
|
${WHISPER_LIB_DIR}/ggml.c
|
||||||
|
${WHISPER_LIB_DIR}/ggml-alloc.c
|
||||||
${WHISPER_LIB_DIR}/whisper.cpp
|
${WHISPER_LIB_DIR}/whisper.cpp
|
||||||
${CMAKE_SOURCE_DIR}/jni.c
|
${CMAKE_SOURCE_DIR}/jni.c
|
||||||
)
|
)
|
||||||
|
@ -163,16 +163,12 @@ Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_freeContext(
|
|||||||
|
|
||||||
JNIEXPORT void JNICALL
|
JNIEXPORT void JNICALL
|
||||||
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_fullTranscribe(
|
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_fullTranscribe(
|
||||||
JNIEnv *env, jobject thiz, jlong context_ptr, jfloatArray audio_data) {
|
JNIEnv *env, jobject thiz, jlong context_ptr, jint num_threads, jfloatArray audio_data) {
|
||||||
UNUSED(thiz);
|
UNUSED(thiz);
|
||||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||||
jfloat *audio_data_arr = (*env)->GetFloatArrayElements(env, audio_data, NULL);
|
jfloat *audio_data_arr = (*env)->GetFloatArrayElements(env, audio_data, NULL);
|
||||||
const jsize audio_data_length = (*env)->GetArrayLength(env, audio_data);
|
const jsize audio_data_length = (*env)->GetArrayLength(env, audio_data);
|
||||||
|
|
||||||
// Leave 2 processors free (i.e. the high-efficiency cores).
|
|
||||||
int max_threads = max(1, min(8, get_nprocs() - 2));
|
|
||||||
LOGI("Selecting %d threads", max_threads);
|
|
||||||
|
|
||||||
// The below adapted from the Objective-C iOS sample
|
// The below adapted from the Objective-C iOS sample
|
||||||
struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||||
params.print_realtime = true;
|
params.print_realtime = true;
|
||||||
@ -181,7 +177,7 @@ Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_fullTranscribe(
|
|||||||
params.print_special = false;
|
params.print_special = false;
|
||||||
params.translate = false;
|
params.translate = false;
|
||||||
params.language = "en";
|
params.language = "en";
|
||||||
params.n_threads = max_threads;
|
params.n_threads = num_threads;
|
||||||
params.offset_ms = 0;
|
params.offset_ms = 0;
|
||||||
params.no_context = true;
|
params.no_context = true;
|
||||||
params.single_segment = false;
|
params.single_segment = false;
|
||||||
|
@ -1,10 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="utf-8"?>
|
|
||||||
<resources>
|
|
||||||
<color name="purple_200">#FFBB86FC</color>
|
|
||||||
<color name="purple_500">#FF6200EE</color>
|
|
||||||
<color name="purple_700">#FF3700B3</color>
|
|
||||||
<color name="teal_200">#FF03DAC5</color>
|
|
||||||
<color name="teal_700">#FF018786</color>
|
|
||||||
<color name="black">#FF000000</color>
|
|
||||||
<color name="white">#FFFFFFFF</color>
|
|
||||||
</resources>
|
|
@ -1,6 +1,6 @@
|
|||||||
// Top-level build file where you can add configuration options common to all sub-projects/modules.
|
// Top-level build file where you can add configuration options common to all sub-projects/modules.
|
||||||
plugins {
|
plugins {
|
||||||
id 'com.android.application' version '7.3.1' apply false
|
id 'com.android.application' version '8.1.1' apply false
|
||||||
id 'com.android.library' version '7.3.1' apply false
|
id 'com.android.library' version '8.1.1' apply false
|
||||||
id 'org.jetbrains.kotlin.android' version '1.7.10' apply false
|
id 'org.jetbrains.kotlin.android' version '1.9.0' apply false
|
||||||
}
|
}
|
@ -1,6 +1,6 @@
|
|||||||
#Wed Dec 14 10:37:24 EST 2022
|
#Wed Dec 14 10:37:24 EST 2022
|
||||||
distributionBase=GRADLE_USER_HOME
|
distributionBase=GRADLE_USER_HOME
|
||||||
distributionUrl=https\://services.gradle.org/distributions/gradle-7.4-bin.zip
|
distributionUrl=https\://services.gradle.org/distributions/gradle-8.2-bin.zip
|
||||||
distributionPath=wrapper/dists
|
distributionPath=wrapper/dists
|
||||||
zipStorePath=wrapper/dists
|
zipStorePath=wrapper/dists
|
||||||
zipStoreBase=GRADLE_USER_HOME
|
zipStoreBase=GRADLE_USER_HOME
|
||||||
|
@ -28,6 +28,8 @@ This can significantly improve the performance of the transcription:
|
|||||||
|
|
||||||
<img width="1072" alt="image" src="https://user-images.githubusercontent.com/1991296/208511239-8d7cdbd1-aa48-41b5-becd-ca288d53cc07.png">
|
<img width="1072" alt="image" src="https://user-images.githubusercontent.com/1991296/208511239-8d7cdbd1-aa48-41b5-becd-ca288d53cc07.png">
|
||||||
|
|
||||||
|
## Core ML
|
||||||
|
|
||||||
If you want to enable Core ML support, you can add the `-DWHISPER_USE_COREML -DWHISPER_COREML_ALLOW_FALLBACK` compiler flag for `whisper.cpp` in Build Phases:
|
If you want to enable Core ML support, you can add the `-DWHISPER_USE_COREML -DWHISPER_COREML_ALLOW_FALLBACK` compiler flag for `whisper.cpp` in Build Phases:
|
||||||
|
|
||||||
<img width="1072" alt="image" src="https://github.com/ggerganov/whisper.cpp/assets/3001525/103e8f57-6eb6-490d-a60c-f6cf6c319324">
|
<img width="1072" alt="image" src="https://github.com/ggerganov/whisper.cpp/assets/3001525/103e8f57-6eb6-490d-a60c-f6cf6c319324">
|
||||||
@ -35,3 +37,13 @@ If you want to enable Core ML support, you can add the `-DWHISPER_USE_COREML -DW
|
|||||||
Then follow the [`Core ML support` section of readme](../../README.md#core-ml-support) for convert the model.
|
Then follow the [`Core ML support` section of readme](../../README.md#core-ml-support) for convert the model.
|
||||||
|
|
||||||
In this project, it also added `-O3 -DNDEBUG` to `Other C Flags`, but adding flags to app proj is not ideal in real world (applies to all C/C++ files), consider splitting xcodeproj in workspace in your own project.
|
In this project, it also added `-O3 -DNDEBUG` to `Other C Flags`, but adding flags to app proj is not ideal in real world (applies to all C/C++ files), consider splitting xcodeproj in workspace in your own project.
|
||||||
|
|
||||||
|
## Metal
|
||||||
|
|
||||||
|
You can also enable Metal to make the inference run on the GPU of your device. This might or might not be more efficient
|
||||||
|
compared to Core ML depending on the model and device that you use.
|
||||||
|
|
||||||
|
To enable Metal, just add `-DGGML_USE_METAL` instead off the `-DWHISPER_USE_COREML` flag and you are ready.
|
||||||
|
This will make both the Encoder and the Decoder run on the GPU.
|
||||||
|
|
||||||
|
If you want to run the Encoder with Core ML and the Decoder with Metal then simply add both `-DWHISPER_USE_COREML -DGGML_USE_METAL` flags. That's all!
|
||||||
|
@ -7,6 +7,9 @@
|
|||||||
objects = {
|
objects = {
|
||||||
|
|
||||||
/* Begin PBXBuildFile section */
|
/* Begin PBXBuildFile section */
|
||||||
|
1844471A2AB211A2007D6BFE /* ggml-alloc.c in Sources */ = {isa = PBXBuildFile; fileRef = 184447182AB211A2007D6BFE /* ggml-alloc.c */; };
|
||||||
|
1844471C2AB21655007D6BFE /* ggml-metal.m in Sources */ = {isa = PBXBuildFile; fileRef = 1844471B2AB21655007D6BFE /* ggml-metal.m */; settings = {COMPILER_FLAGS = "-framework Foundation -framework Metal -framework MetalKit -fno-objc-arc"; }; };
|
||||||
|
184447212AB21B43007D6BFE /* ggml-metal.metal in CopyFiles */ = {isa = PBXBuildFile; fileRef = 1844471D2AB2195F007D6BFE /* ggml-metal.metal */; };
|
||||||
18627C7B29052BDF00BD2A04 /* AppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C7A29052BDF00BD2A04 /* AppDelegate.m */; };
|
18627C7B29052BDF00BD2A04 /* AppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C7A29052BDF00BD2A04 /* AppDelegate.m */; };
|
||||||
18627C7E29052BDF00BD2A04 /* SceneDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C7D29052BDF00BD2A04 /* SceneDelegate.m */; };
|
18627C7E29052BDF00BD2A04 /* SceneDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C7D29052BDF00BD2A04 /* SceneDelegate.m */; };
|
||||||
18627C8129052BDF00BD2A04 /* ViewController.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C8029052BDF00BD2A04 /* ViewController.m */; };
|
18627C8129052BDF00BD2A04 /* ViewController.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C8029052BDF00BD2A04 /* ViewController.m */; };
|
||||||
@ -14,7 +17,7 @@
|
|||||||
18627C8629052BE000BD2A04 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 18627C8529052BE000BD2A04 /* Assets.xcassets */; };
|
18627C8629052BE000BD2A04 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 18627C8529052BE000BD2A04 /* Assets.xcassets */; };
|
||||||
18627C8929052BE000BD2A04 /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 18627C8729052BE000BD2A04 /* LaunchScreen.storyboard */; };
|
18627C8929052BE000BD2A04 /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 18627C8729052BE000BD2A04 /* LaunchScreen.storyboard */; };
|
||||||
18627C8C29052BE000BD2A04 /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C8B29052BE000BD2A04 /* main.m */; };
|
18627C8C29052BE000BD2A04 /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C8B29052BE000BD2A04 /* main.m */; };
|
||||||
18627C9429052C4900BD2A04 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9329052C4900BD2A04 /* whisper.cpp */; settings = {COMPILER_FLAGS = "-DWHISPER_USE_COREML -DWHISPER_COREML_ALLOW_FALLBACK"; }; };
|
18627C9429052C4900BD2A04 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9329052C4900BD2A04 /* whisper.cpp */; settings = {COMPILER_FLAGS = "-DWHISPER_USE_COREML"; }; };
|
||||||
18627C9629052C5800BD2A04 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9529052C5800BD2A04 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE"; }; };
|
18627C9629052C5800BD2A04 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9529052C5800BD2A04 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE"; }; };
|
||||||
18627C9B29052CFF00BD2A04 /* ggml-base.en.bin in Resources */ = {isa = PBXBuildFile; fileRef = 18627C9A29052CFF00BD2A04 /* ggml-base.en.bin */; };
|
18627C9B29052CFF00BD2A04 /* ggml-base.en.bin in Resources */ = {isa = PBXBuildFile; fileRef = 18627C9A29052CFF00BD2A04 /* ggml-base.en.bin */; };
|
||||||
7FE3424B2A0C3FA20015A058 /* whisper-encoder-impl.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FE342452A0C3FA20015A058 /* whisper-encoder-impl.m */; };
|
7FE3424B2A0C3FA20015A058 /* whisper-encoder-impl.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FE342452A0C3FA20015A058 /* whisper-encoder-impl.m */; };
|
||||||
@ -23,7 +26,24 @@
|
|||||||
7FE3424F2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc in Resources */ = {isa = PBXBuildFile; fileRef = 7FE3424E2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc */; };
|
7FE3424F2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc in Resources */ = {isa = PBXBuildFile; fileRef = 7FE3424E2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc */; };
|
||||||
/* End PBXBuildFile section */
|
/* End PBXBuildFile section */
|
||||||
|
|
||||||
|
/* Begin PBXCopyFilesBuildPhase section */
|
||||||
|
184447202AB21B25007D6BFE /* CopyFiles */ = {
|
||||||
|
isa = PBXCopyFilesBuildPhase;
|
||||||
|
buildActionMask = 2147483647;
|
||||||
|
dstPath = "";
|
||||||
|
dstSubfolderSpec = 7;
|
||||||
|
files = (
|
||||||
|
184447212AB21B43007D6BFE /* ggml-metal.metal in CopyFiles */,
|
||||||
|
);
|
||||||
|
runOnlyForDeploymentPostprocessing = 0;
|
||||||
|
};
|
||||||
|
/* End PBXCopyFilesBuildPhase section */
|
||||||
|
|
||||||
/* Begin PBXFileReference section */
|
/* Begin PBXFileReference section */
|
||||||
|
184447182AB211A2007D6BFE /* ggml-alloc.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = "ggml-alloc.c"; path = "../../../ggml-alloc.c"; sourceTree = "<group>"; };
|
||||||
|
184447192AB211A2007D6BFE /* ggml-alloc.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = "ggml-alloc.h"; path = "../../../ggml-alloc.h"; sourceTree = "<group>"; };
|
||||||
|
1844471B2AB21655007D6BFE /* ggml-metal.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; name = "ggml-metal.m"; path = "../../../ggml-metal.m"; sourceTree = "<group>"; };
|
||||||
|
1844471D2AB2195F007D6BFE /* ggml-metal.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; name = "ggml-metal.metal"; path = "../../../ggml-metal.metal"; sourceTree = "<group>"; };
|
||||||
18627C7629052BDF00BD2A04 /* whisper.objc.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = whisper.objc.app; sourceTree = BUILT_PRODUCTS_DIR; };
|
18627C7629052BDF00BD2A04 /* whisper.objc.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = whisper.objc.app; sourceTree = BUILT_PRODUCTS_DIR; };
|
||||||
18627C7929052BDF00BD2A04 /* AppDelegate.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = "<group>"; };
|
18627C7929052BDF00BD2A04 /* AppDelegate.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = "<group>"; };
|
||||||
18627C7A29052BDF00BD2A04 /* AppDelegate.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = AppDelegate.m; sourceTree = "<group>"; };
|
18627C7A29052BDF00BD2A04 /* AppDelegate.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = AppDelegate.m; sourceTree = "<group>"; };
|
||||||
@ -80,6 +100,10 @@
|
|||||||
18627C7829052BDF00BD2A04 /* whisper.objc */ = {
|
18627C7829052BDF00BD2A04 /* whisper.objc */ = {
|
||||||
isa = PBXGroup;
|
isa = PBXGroup;
|
||||||
children = (
|
children = (
|
||||||
|
1844471D2AB2195F007D6BFE /* ggml-metal.metal */,
|
||||||
|
1844471B2AB21655007D6BFE /* ggml-metal.m */,
|
||||||
|
184447182AB211A2007D6BFE /* ggml-alloc.c */,
|
||||||
|
184447192AB211A2007D6BFE /* ggml-alloc.h */,
|
||||||
7FE3424E2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc */,
|
7FE3424E2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc */,
|
||||||
7FE342442A0C3FA20015A058 /* coreml */,
|
7FE342442A0C3FA20015A058 /* coreml */,
|
||||||
18627C9A29052CFF00BD2A04 /* ggml-base.en.bin */,
|
18627C9A29052CFF00BD2A04 /* ggml-base.en.bin */,
|
||||||
@ -126,6 +150,7 @@
|
|||||||
18627C7229052BDF00BD2A04 /* Sources */,
|
18627C7229052BDF00BD2A04 /* Sources */,
|
||||||
18627C7329052BDF00BD2A04 /* Frameworks */,
|
18627C7329052BDF00BD2A04 /* Frameworks */,
|
||||||
18627C7429052BDF00BD2A04 /* Resources */,
|
18627C7429052BDF00BD2A04 /* Resources */,
|
||||||
|
184447202AB21B25007D6BFE /* CopyFiles */,
|
||||||
);
|
);
|
||||||
buildRules = (
|
buildRules = (
|
||||||
);
|
);
|
||||||
@ -194,8 +219,10 @@
|
|||||||
18627C9629052C5800BD2A04 /* ggml.c in Sources */,
|
18627C9629052C5800BD2A04 /* ggml.c in Sources */,
|
||||||
18627C7B29052BDF00BD2A04 /* AppDelegate.m in Sources */,
|
18627C7B29052BDF00BD2A04 /* AppDelegate.m in Sources */,
|
||||||
7FE3424D2A0C3FA20015A058 /* whisper-decoder-impl.m in Sources */,
|
7FE3424D2A0C3FA20015A058 /* whisper-decoder-impl.m in Sources */,
|
||||||
|
1844471A2AB211A2007D6BFE /* ggml-alloc.c in Sources */,
|
||||||
18627C8C29052BE000BD2A04 /* main.m in Sources */,
|
18627C8C29052BE000BD2A04 /* main.m in Sources */,
|
||||||
18627C7E29052BDF00BD2A04 /* SceneDelegate.m in Sources */,
|
18627C7E29052BDF00BD2A04 /* SceneDelegate.m in Sources */,
|
||||||
|
1844471C2AB21655007D6BFE /* ggml-metal.m in Sources */,
|
||||||
7FE3424B2A0C3FA20015A058 /* whisper-encoder-impl.m in Sources */,
|
7FE3424B2A0C3FA20015A058 /* whisper-encoder-impl.m in Sources */,
|
||||||
);
|
);
|
||||||
runOnlyForDeploymentPostprocessing = 0;
|
runOnlyForDeploymentPostprocessing = 0;
|
||||||
|
@ -20,6 +20,7 @@
|
|||||||
0AAC5DCC29539EB1003032C3 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DC929539EB0003032C3 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE -Wno-shorten-64-to-32"; }; };
|
0AAC5DCC29539EB1003032C3 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DC929539EB0003032C3 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE -Wno-shorten-64-to-32"; }; };
|
||||||
0AAC5DCE2953A05C003032C3 /* WhisperState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DCD2953A05C003032C3 /* WhisperState.swift */; };
|
0AAC5DCE2953A05C003032C3 /* WhisperState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DCD2953A05C003032C3 /* WhisperState.swift */; };
|
||||||
0AAC5DD12953A394003032C3 /* LibWhisper.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DD02953A394003032C3 /* LibWhisper.swift */; };
|
0AAC5DD12953A394003032C3 /* LibWhisper.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DD02953A394003032C3 /* LibWhisper.swift */; };
|
||||||
|
18AED4812AB21F2B009D854F /* ggml-alloc.c in Sources */ = {isa = PBXBuildFile; fileRef = 18AED47F2AB21F2B009D854F /* ggml-alloc.c */; };
|
||||||
/* End PBXBuildFile section */
|
/* End PBXBuildFile section */
|
||||||
|
|
||||||
/* Begin PBXFileReference section */
|
/* Begin PBXFileReference section */
|
||||||
@ -41,6 +42,8 @@
|
|||||||
0AAC5DCA29539EB0003032C3 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ggml.h; sourceTree = "<group>"; };
|
0AAC5DCA29539EB0003032C3 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ggml.h; sourceTree = "<group>"; };
|
||||||
0AAC5DCD2953A05C003032C3 /* WhisperState.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = WhisperState.swift; sourceTree = "<group>"; };
|
0AAC5DCD2953A05C003032C3 /* WhisperState.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = WhisperState.swift; sourceTree = "<group>"; };
|
||||||
0AAC5DD02953A394003032C3 /* LibWhisper.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LibWhisper.swift; sourceTree = "<group>"; };
|
0AAC5DD02953A394003032C3 /* LibWhisper.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LibWhisper.swift; sourceTree = "<group>"; };
|
||||||
|
18AED47F2AB21F2B009D854F /* ggml-alloc.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = "ggml-alloc.c"; sourceTree = "<group>"; };
|
||||||
|
18AED4802AB21F2B009D854F /* ggml-alloc.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "ggml-alloc.h"; sourceTree = "<group>"; };
|
||||||
/* End PBXFileReference section */
|
/* End PBXFileReference section */
|
||||||
|
|
||||||
/* Begin PBXFrameworksBuildPhase section */
|
/* Begin PBXFrameworksBuildPhase section */
|
||||||
@ -124,6 +127,8 @@
|
|||||||
0AAC5DC529539E89003032C3 /* whisper.cpp */ = {
|
0AAC5DC529539E89003032C3 /* whisper.cpp */ = {
|
||||||
isa = PBXGroup;
|
isa = PBXGroup;
|
||||||
children = (
|
children = (
|
||||||
|
18AED47F2AB21F2B009D854F /* ggml-alloc.c */,
|
||||||
|
18AED4802AB21F2B009D854F /* ggml-alloc.h */,
|
||||||
0AAC5DC929539EB0003032C3 /* ggml.c */,
|
0AAC5DC929539EB0003032C3 /* ggml.c */,
|
||||||
0AAC5DCA29539EB0003032C3 /* ggml.h */,
|
0AAC5DCA29539EB0003032C3 /* ggml.h */,
|
||||||
0AAC5DC729539EB0003032C3 /* whisper.cpp */,
|
0AAC5DC729539EB0003032C3 /* whisper.cpp */,
|
||||||
@ -242,6 +247,7 @@
|
|||||||
0AA7514C2953B569001EE061 /* RiffWaveUtils.swift in Sources */,
|
0AA7514C2953B569001EE061 /* RiffWaveUtils.swift in Sources */,
|
||||||
0AAC5DCB29539EB1003032C3 /* whisper.cpp in Sources */,
|
0AAC5DCB29539EB1003032C3 /* whisper.cpp in Sources */,
|
||||||
0AA7514E2953D958001EE061 /* Recorder.swift in Sources */,
|
0AA7514E2953D958001EE061 /* Recorder.swift in Sources */,
|
||||||
|
18AED4812AB21F2B009D854F /* ggml-alloc.c in Sources */,
|
||||||
);
|
);
|
||||||
runOnlyForDeploymentPostprocessing = 0;
|
runOnlyForDeploymentPostprocessing = 0;
|
||||||
};
|
};
|
||||||
@ -369,7 +375,7 @@
|
|||||||
CODE_SIGN_STYLE = Automatic;
|
CODE_SIGN_STYLE = Automatic;
|
||||||
CURRENT_PROJECT_VERSION = 1;
|
CURRENT_PROJECT_VERSION = 1;
|
||||||
DEVELOPMENT_ASSET_PATHS = "\"whisper.swiftui.demo/Supporting files/Preview Content\"";
|
DEVELOPMENT_ASSET_PATHS = "\"whisper.swiftui.demo/Supporting files/Preview Content\"";
|
||||||
DEVELOPMENT_TEAM = 3TZ9BM962G;
|
DEVELOPMENT_TEAM = P8JZH34X63;
|
||||||
ENABLE_HARDENED_RUNTIME = YES;
|
ENABLE_HARDENED_RUNTIME = YES;
|
||||||
ENABLE_PREVIEWS = YES;
|
ENABLE_PREVIEWS = YES;
|
||||||
GENERATE_INFOPLIST_FILE = YES;
|
GENERATE_INFOPLIST_FILE = YES;
|
||||||
@ -410,7 +416,7 @@
|
|||||||
CODE_SIGN_STYLE = Automatic;
|
CODE_SIGN_STYLE = Automatic;
|
||||||
CURRENT_PROJECT_VERSION = 1;
|
CURRENT_PROJECT_VERSION = 1;
|
||||||
DEVELOPMENT_ASSET_PATHS = "\"whisper.swiftui.demo/Supporting files/Preview Content\"";
|
DEVELOPMENT_ASSET_PATHS = "\"whisper.swiftui.demo/Supporting files/Preview Content\"";
|
||||||
DEVELOPMENT_TEAM = 3TZ9BM962G;
|
DEVELOPMENT_TEAM = P8JZH34X63;
|
||||||
ENABLE_HARDENED_RUNTIME = YES;
|
ENABLE_HARDENED_RUNTIME = YES;
|
||||||
ENABLE_PREVIEWS = YES;
|
ENABLE_PREVIEWS = YES;
|
||||||
GENERATE_INFOPLIST_FILE = YES;
|
GENERATE_INFOPLIST_FILE = YES;
|
||||||
|
@ -44,27 +44,26 @@ if [ "$encoder_only" -eq 0 ]; then
|
|||||||
printf "\n"
|
printf "\n"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
printf "| CPU | OS | Config | Model | Th | Load | Enc. | Commit |\n"
|
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit"
|
||||||
printf "| --- | -- | ------ | ----- | -- | ---- | ---- | ------ |\n"
|
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---"
|
||||||
|
|
||||||
for model in "${models[@]}"; do
|
for model in "${models[@]}"; do
|
||||||
# run once to heat-up the cache
|
|
||||||
./bench -m ./models/ggml-$model.bin -t $n_threads 2>/dev/null 1>/dev/null
|
|
||||||
|
|
||||||
# actual run
|
# actual run
|
||||||
# store stderr output in a variable in order to parse it later
|
# store stderr output in a variable in order to parse it later
|
||||||
output=$(./bench -m ./models/ggml-$model.bin -t $n_threads 2>&1)
|
output=$(./bench -m ./models/ggml-$model.bin -t $n_threads 2>&1)
|
||||||
ret=$?
|
ret=$?
|
||||||
|
|
||||||
# parse the output:
|
# parse the output:
|
||||||
load_time=$(echo "$output" | grep "load time" | awk '{print $5}')
|
encode_time=$(echo "$output" | grep "encode time" | awk '{print $11}')
|
||||||
encode_time=$(echo "$output" | grep "encode time" | awk '{print $5}')
|
decode_time=$(echo "$output" | grep "decode time" | awk '{print $11}')
|
||||||
|
prompt_time=$(echo "$output" | grep "prompt time" | awk '{print $11}')
|
||||||
system_info=$(echo "$output" | grep "system_info")
|
system_info=$(echo "$output" | grep "system_info")
|
||||||
n_threads=$(echo "$output" | grep "system_info" | awk '{print $4}')
|
n_threads=$(echo "$output" | grep "system_info" | awk '{print $4}')
|
||||||
|
|
||||||
# floor to milliseconds
|
# floor to milliseconds
|
||||||
load_time=${load_time%.*}
|
#encode_time=${encode_time%.*}
|
||||||
encode_time=${encode_time%.*}
|
#decode_time=${decode_time%.*}
|
||||||
|
#prompt_time=${prompt_time%.*}
|
||||||
|
|
||||||
config=""
|
config=""
|
||||||
|
|
||||||
@ -84,9 +83,13 @@ for model in "${models[@]}"; do
|
|||||||
config="$config COREML"
|
config="$config COREML"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
if [[ $system_info == *"METAL = 1"* ]]; then
|
||||||
|
config="$config METAL"
|
||||||
|
fi
|
||||||
|
|
||||||
commit=$(git rev-parse --short HEAD)
|
commit=$(git rev-parse --short HEAD)
|
||||||
|
|
||||||
if [ $ret -eq 0 ]; then
|
if [ $ret -eq 0 ]; then
|
||||||
printf "| <todo> | <todo> | $config | $model | $n_threads | $load_time | $encode_time | $commit |\n"
|
printf "| <todo> | <todo> | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit"
|
||||||
fi
|
fi
|
||||||
done
|
done
|
||||||
|
222
extra/bench.py
Normal file
222
extra/bench.py
Normal file
@ -0,0 +1,222 @@
|
|||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import re
|
||||||
|
import csv
|
||||||
|
import wave
|
||||||
|
import contextlib
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
# Custom action to handle comma-separated list
|
||||||
|
class ListAction(argparse.Action):
|
||||||
|
def __call__(self, parser, namespace, values, option_string=None):
|
||||||
|
setattr(namespace, self.dest, [int(val) for val in values.split(",")])
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="Benchmark the speech recognition model")
|
||||||
|
|
||||||
|
# Define the argument to accept a list
|
||||||
|
parser.add_argument(
|
||||||
|
"-t",
|
||||||
|
"--threads",
|
||||||
|
dest="threads",
|
||||||
|
action=ListAction,
|
||||||
|
default=[4],
|
||||||
|
help="List of thread counts to benchmark (comma-separated, default: 4)",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"-p",
|
||||||
|
"--processors",
|
||||||
|
dest="processors",
|
||||||
|
action=ListAction,
|
||||||
|
default=[1],
|
||||||
|
help="List of processor counts to benchmark (comma-separated, default: 1)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"-f",
|
||||||
|
"--filename",
|
||||||
|
type=str,
|
||||||
|
default="./samples/jfk.wav",
|
||||||
|
help="Relative path of the file to transcribe (default: ./samples/jfk.wav)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse the command line arguments
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
sample_file = args.filename
|
||||||
|
|
||||||
|
threads = args.threads
|
||||||
|
processors = args.processors
|
||||||
|
|
||||||
|
# Define the models, threads, and processor counts to benchmark
|
||||||
|
models = [
|
||||||
|
"ggml-tiny.en.bin",
|
||||||
|
"ggml-tiny.bin",
|
||||||
|
"ggml-base.en.bin",
|
||||||
|
"ggml-base.bin",
|
||||||
|
"ggml-small.en.bin",
|
||||||
|
"ggml-small.bin",
|
||||||
|
"ggml-medium.en.bin",
|
||||||
|
"ggml-medium.bin",
|
||||||
|
"ggml-large.bin",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
metal_device = ""
|
||||||
|
|
||||||
|
# Initialize a dictionary to hold the results
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
gitHashHeader = "Commit"
|
||||||
|
modelHeader = "Model"
|
||||||
|
hardwareHeader = "Hardware"
|
||||||
|
recordingLengthHeader = "Recording Length (seconds)"
|
||||||
|
threadHeader = "Thread"
|
||||||
|
processorCountHeader = "Processor Count"
|
||||||
|
loadTimeHeader = "Load Time (ms)"
|
||||||
|
sampleTimeHeader = "Sample Time (ms)"
|
||||||
|
encodeTimeHeader = "Encode Time (ms)"
|
||||||
|
decodeTimeHeader = "Decode Time (ms)"
|
||||||
|
sampleTimePerRunHeader = "Sample Time per Run (ms)"
|
||||||
|
encodeTimePerRunHeader = "Encode Time per Run (ms)"
|
||||||
|
decodeTimePerRunHeader = "Decode Time per Run (ms)"
|
||||||
|
totalTimeHeader = "Total Time (ms)"
|
||||||
|
|
||||||
|
|
||||||
|
def check_file_exists(file: str) -> bool:
|
||||||
|
return os.path.isfile(file)
|
||||||
|
|
||||||
|
|
||||||
|
def get_git_short_hash() -> str:
|
||||||
|
try:
|
||||||
|
return (
|
||||||
|
subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
|
||||||
|
.decode()
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def wav_file_length(file: str = sample_file) -> float:
|
||||||
|
with contextlib.closing(wave.open(file, "r")) as f:
|
||||||
|
frames = f.getnframes()
|
||||||
|
rate = f.getframerate()
|
||||||
|
duration = frames / float(rate)
|
||||||
|
return duration
|
||||||
|
|
||||||
|
|
||||||
|
def extract_metrics(output: str, label: str) -> tuple[float, float]:
|
||||||
|
match = re.search(rf"{label} \s*=\s*(\d+\.\d+)\s*ms\s*/\s*(\d+)\s*runs", output)
|
||||||
|
time = float(match.group(1)) if match else None
|
||||||
|
runs = float(match.group(2)) if match else None
|
||||||
|
return time, runs
|
||||||
|
|
||||||
|
|
||||||
|
def extract_device(output: str) -> str:
|
||||||
|
match = re.search(r"picking default device: (.*)", output)
|
||||||
|
device = match.group(1) if match else "Not found"
|
||||||
|
return device
|
||||||
|
|
||||||
|
|
||||||
|
# Check if the sample file exists
|
||||||
|
if not check_file_exists(sample_file):
|
||||||
|
raise FileNotFoundError(f"Sample file {sample_file} not found")
|
||||||
|
|
||||||
|
recording_length = wav_file_length()
|
||||||
|
|
||||||
|
|
||||||
|
# Check that all models exist
|
||||||
|
# Filter out models from list that are not downloaded
|
||||||
|
filtered_models = []
|
||||||
|
for model in models:
|
||||||
|
if check_file_exists(f"models/{model}"):
|
||||||
|
filtered_models.append(model)
|
||||||
|
else:
|
||||||
|
print(f"Model {model} not found, removing from list")
|
||||||
|
|
||||||
|
models = filtered_models
|
||||||
|
|
||||||
|
# Loop over each combination of parameters
|
||||||
|
for model in filtered_models:
|
||||||
|
for thread in threads:
|
||||||
|
for processor_count in processors:
|
||||||
|
# Construct the command to run
|
||||||
|
cmd = f"./main -m models/{model} -t {thread} -p {processor_count} -f {sample_file}"
|
||||||
|
# Run the command and get the output
|
||||||
|
process = subprocess.Popen(
|
||||||
|
cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
|
||||||
|
)
|
||||||
|
|
||||||
|
output = ""
|
||||||
|
while process.poll() is None:
|
||||||
|
output += process.stdout.read().decode()
|
||||||
|
|
||||||
|
# Parse the output
|
||||||
|
load_time_match = re.search(r"load time\s*=\s*(\d+\.\d+)\s*ms", output)
|
||||||
|
load_time = float(load_time_match.group(1)) if load_time_match else None
|
||||||
|
|
||||||
|
metal_device = extract_device(output)
|
||||||
|
sample_time, sample_runs = extract_metrics(output, "sample time")
|
||||||
|
encode_time, encode_runs = extract_metrics(output, "encode time")
|
||||||
|
decode_time, decode_runs = extract_metrics(output, "decode time")
|
||||||
|
|
||||||
|
total_time_match = re.search(r"total time\s*=\s*(\d+\.\d+)\s*ms", output)
|
||||||
|
total_time = float(total_time_match.group(1)) if total_time_match else None
|
||||||
|
|
||||||
|
model_name = model.replace("ggml-", "").replace(".bin", "")
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Ran model={model_name} threads={thread} processor_count={processor_count}, took {total_time}ms"
|
||||||
|
)
|
||||||
|
# Store the times in the results dictionary
|
||||||
|
results[(model_name, thread, processor_count)] = {
|
||||||
|
loadTimeHeader: load_time,
|
||||||
|
sampleTimeHeader: sample_time,
|
||||||
|
encodeTimeHeader: encode_time,
|
||||||
|
decodeTimeHeader: decode_time,
|
||||||
|
sampleTimePerRunHeader: round(sample_time / sample_runs, 2),
|
||||||
|
encodeTimePerRunHeader: round(encode_time / encode_runs, 2),
|
||||||
|
decodeTimePerRunHeader: round(decode_time / decode_runs, 2),
|
||||||
|
totalTimeHeader: total_time,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Write the results to a CSV file
|
||||||
|
with open("benchmark_results.csv", "w", newline="") as csvfile:
|
||||||
|
fieldnames = [
|
||||||
|
gitHashHeader,
|
||||||
|
modelHeader,
|
||||||
|
hardwareHeader,
|
||||||
|
recordingLengthHeader,
|
||||||
|
threadHeader,
|
||||||
|
processorCountHeader,
|
||||||
|
loadTimeHeader,
|
||||||
|
sampleTimeHeader,
|
||||||
|
encodeTimeHeader,
|
||||||
|
decodeTimeHeader,
|
||||||
|
sampleTimePerRunHeader,
|
||||||
|
encodeTimePerRunHeader,
|
||||||
|
decodeTimePerRunHeader,
|
||||||
|
totalTimeHeader,
|
||||||
|
]
|
||||||
|
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||||
|
|
||||||
|
writer.writeheader()
|
||||||
|
|
||||||
|
shortHash = get_git_short_hash()
|
||||||
|
# Sort the results by total time in ascending order
|
||||||
|
sorted_results = sorted(results.items(), key=lambda x: x[1].get(totalTimeHeader, 0))
|
||||||
|
for params, times in sorted_results:
|
||||||
|
row = {
|
||||||
|
gitHashHeader: shortHash,
|
||||||
|
modelHeader: params[0],
|
||||||
|
hardwareHeader: metal_device,
|
||||||
|
recordingLengthHeader: recording_length,
|
||||||
|
threadHeader: params[1],
|
||||||
|
processorCountHeader: params[2],
|
||||||
|
}
|
||||||
|
row.update(times)
|
||||||
|
writer.writerow(row)
|
@ -1,6 +1,7 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
cp -rpv ../ggml/src/ggml.c ./ggml.c
|
cp -rpv ../ggml/src/ggml.c ./ggml.c
|
||||||
|
cp -rpv ../ggml/src/ggml-alloc.c ./ggml-alloc.c
|
||||||
cp -rpv ../ggml/src/ggml-cuda.h ./ggml-cuda.h
|
cp -rpv ../ggml/src/ggml-cuda.h ./ggml-cuda.h
|
||||||
cp -rpv ../ggml/src/ggml-cuda.cu ./ggml-cuda.cu
|
cp -rpv ../ggml/src/ggml-cuda.cu ./ggml-cuda.cu
|
||||||
cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h
|
cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h
|
||||||
@ -9,6 +10,7 @@ cp -rpv ../ggml/src/ggml-metal.h ./ggml-metal.h
|
|||||||
cp -rpv ../ggml/src/ggml-metal.m ./ggml-metal.m
|
cp -rpv ../ggml/src/ggml-metal.m ./ggml-metal.m
|
||||||
cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal
|
cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal
|
||||||
cp -rpv ../ggml/include/ggml/ggml.h ./ggml.h
|
cp -rpv ../ggml/include/ggml/ggml.h ./ggml.h
|
||||||
|
cp -rpv ../ggml/include/ggml/ggml-alloc.h ./ggml-alloc.h
|
||||||
cp -rpv ../ggml/examples/common.h ./examples/common.h
|
cp -rpv ../ggml/examples/common.h ./examples/common.h
|
||||||
cp -rpv ../ggml/examples/common.cpp ./examples/common.cpp
|
cp -rpv ../ggml/examples/common.cpp ./examples/common.cpp
|
||||||
cp -rpv ../ggml/examples/common-ggml.h ./examples/common-ggml.h
|
cp -rpv ../ggml/examples/common-ggml.h ./examples/common-ggml.h
|
||||||
|
183
ggml-alloc.c
183
ggml-alloc.c
@ -6,6 +6,26 @@
|
|||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
|
|
||||||
|
#ifdef __has_include
|
||||||
|
#if __has_include(<unistd.h>)
|
||||||
|
#include <unistd.h>
|
||||||
|
#if defined(_POSIX_MAPPED_FILES)
|
||||||
|
#include <sys/types.h>
|
||||||
|
#include <sys/mman.h>
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(_WIN32)
|
||||||
|
#define WIN32_LEAN_AND_MEAN
|
||||||
|
#ifndef NOMINMAX
|
||||||
|
#define NOMINMAX
|
||||||
|
#endif
|
||||||
|
#include <windows.h>
|
||||||
|
#include <memoryapi.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
#define UNUSED(x) (void)(x)
|
#define UNUSED(x) (void)(x)
|
||||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||||
#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
|
#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
|
||||||
@ -99,15 +119,28 @@ static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tens
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
static size_t ggml_allocr_get_alloc_size(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
|
||||||
static size_t ggml_allocator_get_alloc_size(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
|
|
||||||
return ggml_nbytes(tensor);
|
return ggml_nbytes(tensor);
|
||||||
|
|
||||||
UNUSED(alloc);
|
UNUSED(alloc);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check if a tensor is allocated by this buffer
|
||||||
|
static bool ggml_allocr_is_own(struct ggml_allocr * alloc, const struct ggml_tensor * tensor) {
|
||||||
|
void * ptr = tensor->data;
|
||||||
|
return ptr >= alloc->data && (char *)ptr < (char *)alloc->data + alloc->max_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool ggml_is_view(struct ggml_tensor * t) {
|
||||||
|
return t->view_src != NULL;
|
||||||
|
}
|
||||||
|
|
||||||
void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
|
void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
|
||||||
size_t size = ggml_allocator_get_alloc_size(alloc, tensor);
|
#ifdef GGML_ALLOCATOR_DEBUG
|
||||||
|
GGML_ASSERT(!ggml_is_view(tensor)); // views generally get data pointer from one of their sources
|
||||||
|
GGML_ASSERT(tensor->data == NULL); // avoid allocating tensor which already has memory allocated
|
||||||
|
#endif
|
||||||
|
size_t size = ggml_allocr_get_alloc_size(alloc, tensor);
|
||||||
size = aligned_offset(NULL, size, alloc->alignment);
|
size = aligned_offset(NULL, size, alloc->alignment);
|
||||||
|
|
||||||
AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size);
|
AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size);
|
||||||
@ -131,9 +164,9 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor)
|
|||||||
if (best_fit_block == -1) {
|
if (best_fit_block == -1) {
|
||||||
// the last block is our last resort
|
// the last block is our last resort
|
||||||
struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1];
|
struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1];
|
||||||
|
max_avail = MAX(max_avail, block->size);
|
||||||
if (block->size >= size) {
|
if (block->size >= size) {
|
||||||
best_fit_block = alloc->n_free_blocks - 1;
|
best_fit_block = alloc->n_free_blocks - 1;
|
||||||
max_avail = MAX(max_avail, block->size);
|
|
||||||
} else {
|
} else {
|
||||||
fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n",
|
fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n",
|
||||||
__func__, size, max_avail);
|
__func__, size, max_avail);
|
||||||
@ -173,17 +206,17 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// this is a very naive implementation, but for our case the number of free blocks should be very small
|
// this is a very naive implementation, but for our case the number of free blocks should be very small
|
||||||
static void ggml_allocator_free_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
|
static void ggml_allocr_free_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
|
||||||
void * ptr = tensor->data;
|
void * ptr = tensor->data;
|
||||||
|
|
||||||
if (ptr < alloc->data || (char*)ptr >= (char*)alloc->data + alloc->max_size) {
|
if (ggml_allocr_is_own(alloc, tensor) == false) {
|
||||||
// the tensor was not allocated in this buffer
|
// the tensor was not allocated in this buffer
|
||||||
// this can happen because the graph allocator will try to free weights and other tensors from different buffers
|
// this can happen because the graph allocator will try to free weights and other tensors from different buffers
|
||||||
// the easiest way to deal with this is just to ignore it
|
// the easiest way to deal with this is just to ignore it
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t size = ggml_allocator_get_alloc_size(alloc, tensor);
|
size_t size = ggml_allocr_get_alloc_size(alloc, tensor);
|
||||||
size = aligned_offset(NULL, size, alloc->alignment);
|
size = aligned_offset(NULL, size, alloc->alignment);
|
||||||
AT_PRINTF("%s: freeing %s (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, size, alloc->n_free_blocks);
|
AT_PRINTF("%s: freeing %s (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, size, alloc->n_free_blocks);
|
||||||
|
|
||||||
@ -277,17 +310,68 @@ struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment)
|
|||||||
return alloc;
|
return alloc;
|
||||||
}
|
}
|
||||||
|
|
||||||
// address and size of the buffer when measuring
|
// OS specific functions to allocate and free uncommitted virtual memory
|
||||||
// it needs to be large enough to fit all the tensors, but it cannot overlap with other existing buffers
|
static void * alloc_vmem(size_t size) {
|
||||||
static void * const MEASURE_BASE_ADDR = (void *) 0x1000;
|
#if defined(_WIN32)
|
||||||
static const size_t MEASURE_MAX_SIZE = 1ULL<<40; // 1 TB
|
return VirtualAlloc(NULL, size, MEM_RESERVE, PAGE_NOACCESS);
|
||||||
|
#elif defined(_POSIX_MAPPED_FILES)
|
||||||
|
void * ptr = mmap(NULL, size, PROT_NONE, MAP_PRIVATE | MAP_ANON, -1, 0);
|
||||||
|
if (ptr == MAP_FAILED) {
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
return ptr;
|
||||||
|
#else
|
||||||
|
// use a fixed address for other platforms
|
||||||
|
uintptr_t base_addr = (uintptr_t)-size - 0x100;
|
||||||
|
return (void *)base_addr;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
static void free_vmem(void * base_addr, size_t size) {
|
||||||
|
#if defined(_WIN32)
|
||||||
|
VirtualFree(base_addr, 0, MEM_RELEASE);
|
||||||
|
UNUSED(size);
|
||||||
|
#elif defined(_POSIX_MAPPED_FILES)
|
||||||
|
munmap(base_addr, size);
|
||||||
|
#else
|
||||||
|
// nothing to do
|
||||||
|
UNUSED(base_addr);
|
||||||
|
UNUSED(size);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// allocate uncommitted virtual memory to measure the size of the graph
|
||||||
|
static void alloc_measure_vmem(void ** base_addr, size_t * size) {
|
||||||
|
// 128GB for 64-bit, 1GB for 32-bit
|
||||||
|
*size = sizeof(void *) == 4 ? 1ULL<<30 : 1ULL<<37;
|
||||||
|
do {
|
||||||
|
*base_addr = alloc_vmem(*size);
|
||||||
|
if (*base_addr != NULL) {
|
||||||
|
AT_PRINTF("allocated %.2f GB of virtual memory for measure buffer at %p\n", *size / 1024.0 / 1024.0 / 1024.0, *base_addr);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// try again with half the size
|
||||||
|
*size /= 2;
|
||||||
|
} while (*size > 0);
|
||||||
|
|
||||||
|
GGML_ASSERT(!"failed to allocate virtual memory for measure buffer");
|
||||||
|
}
|
||||||
|
|
||||||
|
static void free_measure_vmem(void * base_addr, size_t size) {
|
||||||
|
free_vmem(base_addr, size);
|
||||||
|
}
|
||||||
|
|
||||||
struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
|
struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
|
||||||
struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */);
|
struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */);
|
||||||
|
|
||||||
|
void * base_addr;
|
||||||
|
size_t size;
|
||||||
|
|
||||||
|
alloc_measure_vmem(&base_addr, &size);
|
||||||
|
|
||||||
*alloc = (struct ggml_allocr){
|
*alloc = (struct ggml_allocr){
|
||||||
/*.data = */ MEASURE_BASE_ADDR,
|
/*.data = */ base_addr,
|
||||||
/*.size = */ MEASURE_MAX_SIZE,
|
/*.size = */ size,
|
||||||
/*.alignment = */ alignment,
|
/*.alignment = */ alignment,
|
||||||
/*.n_free_blocks = */ 0,
|
/*.n_free_blocks = */ 0,
|
||||||
/*.free_blocks = */ {{0}},
|
/*.free_blocks = */ {{0}},
|
||||||
@ -307,6 +391,9 @@ struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void ggml_allocr_free(struct ggml_allocr * alloc) {
|
void ggml_allocr_free(struct ggml_allocr * alloc) {
|
||||||
|
if (alloc->measure) {
|
||||||
|
free_measure_vmem(alloc->data, alloc->size);
|
||||||
|
}
|
||||||
free(alloc);
|
free(alloc);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -316,11 +403,6 @@ bool ggml_allocr_is_measure(struct ggml_allocr * alloc) {
|
|||||||
|
|
||||||
//////////// compute graph allocator
|
//////////// compute graph allocator
|
||||||
|
|
||||||
static bool ggml_is_view(struct ggml_tensor * t) {
|
|
||||||
return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE ||
|
|
||||||
t->op == GGML_OP_PERMUTE || t->op == GGML_OP_CPY;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
|
static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
|
||||||
if (a->type != b->type) {
|
if (a->type != b->type) {
|
||||||
return false;
|
return false;
|
||||||
@ -336,28 +418,6 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
static struct ggml_tensor * get_view_parent(struct ggml_tensor * t) {
|
|
||||||
switch (t->op) {
|
|
||||||
case GGML_OP_PERMUTE:
|
|
||||||
case GGML_OP_RESHAPE:
|
|
||||||
case GGML_OP_TRANSPOSE:
|
|
||||||
case GGML_OP_VIEW:
|
|
||||||
return t->src[0];
|
|
||||||
case GGML_OP_CPY:
|
|
||||||
return t->src[1];
|
|
||||||
default:
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static struct ggml_tensor * get_view_source(struct ggml_tensor * t) {
|
|
||||||
struct ggml_tensor * parent = t;
|
|
||||||
do {
|
|
||||||
parent = get_view_parent(parent);
|
|
||||||
} while (ggml_is_view(parent));
|
|
||||||
return parent;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool ggml_op_can_inplace(enum ggml_op op) {
|
static bool ggml_op_can_inplace(enum ggml_op op) {
|
||||||
switch (op) {
|
switch (op) {
|
||||||
case GGML_OP_SCALE:
|
case GGML_OP_SCALE:
|
||||||
@ -365,7 +425,6 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
|
|||||||
case GGML_OP_DIAG_MASK_INF:
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
case GGML_OP_ADD:
|
case GGML_OP_ADD:
|
||||||
case GGML_OP_ADD1:
|
case GGML_OP_ADD1:
|
||||||
case GGML_OP_ACC:
|
|
||||||
case GGML_OP_SUB:
|
case GGML_OP_SUB:
|
||||||
case GGML_OP_MUL:
|
case GGML_OP_MUL:
|
||||||
case GGML_OP_DIV:
|
case GGML_OP_DIV:
|
||||||
@ -375,10 +434,8 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
|
|||||||
case GGML_OP_UNARY:
|
case GGML_OP_UNARY:
|
||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
case GGML_OP_SET:
|
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
case GGML_OP_CONT:
|
case GGML_OP_CONT:
|
||||||
case GGML_OP_ADD_REL_POS:
|
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
default:
|
default:
|
||||||
@ -390,24 +447,8 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
|
|||||||
struct hash_node * ht = alloc->hash_table;
|
struct hash_node * ht = alloc->hash_table;
|
||||||
if (node->data == NULL) {
|
if (node->data == NULL) {
|
||||||
if (ggml_is_view(node)) {
|
if (ggml_is_view(node)) {
|
||||||
size_t offset;
|
assert(node->view_src->data != NULL);
|
||||||
switch(node->op) {
|
node->data = (char *)node->view_src->data + node->view_offs;
|
||||||
case GGML_OP_VIEW:
|
|
||||||
memcpy(&offset, node->op_params, sizeof(size_t));
|
|
||||||
node->data = (char *) node->src[0]->data + offset;
|
|
||||||
break;
|
|
||||||
case GGML_OP_PERMUTE:
|
|
||||||
case GGML_OP_RESHAPE:
|
|
||||||
case GGML_OP_TRANSPOSE:
|
|
||||||
node->data = node->src[0]->data;
|
|
||||||
break;
|
|
||||||
case GGML_OP_CPY:
|
|
||||||
node->data = node->src[1]->data;
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
GGML_ASSERT(!"unknown view op");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
// see if we can reuse a parent's buffer (inplace)
|
// see if we can reuse a parent's buffer (inplace)
|
||||||
if (ggml_op_can_inplace(node->op)) {
|
if (ggml_op_can_inplace(node->op)) {
|
||||||
@ -418,8 +459,7 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// if the node's data is external, then we cannot re-use it
|
// if the node's data is external, then we cannot re-use it
|
||||||
if ((char *) parent->data < (char *) alloc->data ||
|
if (ggml_allocr_is_own(alloc, parent) == false) {
|
||||||
(char *) parent->data >= ((char *) alloc->data + alloc->size)) {
|
|
||||||
AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data);
|
AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -427,7 +467,7 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
|
|||||||
struct hash_node * p_hn = hash_get(ht, parent);
|
struct hash_node * p_hn = hash_get(ht, parent);
|
||||||
if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_are_same_layout(node, parent)) {
|
if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_are_same_layout(node, parent)) {
|
||||||
if (ggml_is_view(parent)) {
|
if (ggml_is_view(parent)) {
|
||||||
struct ggml_tensor * view_src = get_view_source(parent);
|
struct ggml_tensor * view_src = parent->view_src;
|
||||||
struct hash_node * view_src_hn = hash_get(ht, view_src);
|
struct hash_node * view_src_hn = hash_get(ht, view_src);
|
||||||
if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
|
if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
|
||||||
// TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite
|
// TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite
|
||||||
@ -453,7 +493,7 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static size_t ggml_allocator_alloc_graph_tensors_n(
|
static size_t ggml_allocr_alloc_graph_tensors_n(
|
||||||
struct ggml_allocr * alloc,
|
struct ggml_allocr * alloc,
|
||||||
struct ggml_cgraph ** graphs, int n_graphs,
|
struct ggml_cgraph ** graphs, int n_graphs,
|
||||||
struct ggml_tensor *** inputs, struct ggml_tensor *** outputs) {
|
struct ggml_tensor *** inputs, struct ggml_tensor *** outputs) {
|
||||||
@ -469,7 +509,7 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
|
|||||||
struct ggml_tensor * node = gf->nodes[i];
|
struct ggml_tensor * node = gf->nodes[i];
|
||||||
|
|
||||||
if (ggml_is_view(node)) {
|
if (ggml_is_view(node)) {
|
||||||
struct ggml_tensor * view_src = get_view_source(node);
|
struct ggml_tensor * view_src = node->view_src;
|
||||||
hash_get(ht, view_src)->n_views += 1;
|
hash_get(ht, view_src)->n_views += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -531,7 +571,6 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
|
|||||||
AT_PRINTF("\n");
|
AT_PRINTF("\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// update parents
|
// update parents
|
||||||
// update immediately if there is no parse_seq
|
// update immediately if there is no parse_seq
|
||||||
// update only at barriers if there is parse_seq
|
// update only at barriers if there is parse_seq
|
||||||
@ -554,17 +593,17 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
|
|||||||
|
|
||||||
if (p_hn->n_children == 0 && p_hn->n_views == 0) {
|
if (p_hn->n_children == 0 && p_hn->n_views == 0) {
|
||||||
if (ggml_is_view(parent)) {
|
if (ggml_is_view(parent)) {
|
||||||
struct ggml_tensor * view_src = get_view_source(parent);
|
struct ggml_tensor * view_src = parent->view_src;
|
||||||
struct hash_node * view_src_hn = hash_get(ht, view_src);
|
struct hash_node * view_src_hn = hash_get(ht, view_src);
|
||||||
view_src_hn->n_views -= 1;
|
view_src_hn->n_views -= 1;
|
||||||
AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
|
AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
|
||||||
if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
|
if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
|
||||||
ggml_allocator_free_tensor(alloc, view_src);
|
ggml_allocr_free_tensor(alloc, view_src);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
if (parent->data != node->data) {
|
if (parent->data != node->data) {
|
||||||
ggml_allocator_free_tensor(alloc, parent);
|
ggml_allocr_free_tensor(alloc, parent);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -581,7 +620,7 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
|
|||||||
for (int i = 0; outputs[g][i] != NULL; i++) {
|
for (int i = 0; outputs[g][i] != NULL; i++) {
|
||||||
struct ggml_tensor * output = outputs[g][i];
|
struct ggml_tensor * output = outputs[g][i];
|
||||||
AT_PRINTF("output: %s\n", output->name);
|
AT_PRINTF("output: %s\n", output->name);
|
||||||
ggml_allocator_free_tensor(alloc, output);
|
ggml_allocr_free_tensor(alloc, output);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -590,5 +629,5 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
|
|||||||
}
|
}
|
||||||
|
|
||||||
size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) {
|
size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) {
|
||||||
return ggml_allocator_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL);
|
return ggml_allocr_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL);
|
||||||
}
|
}
|
||||||
|
34
ggml-cuda.cu
34
ggml-cuda.cu
@ -4086,7 +4086,8 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
|
|||||||
dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
|
dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
|
||||||
}
|
}
|
||||||
|
|
||||||
static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) {
|
static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p0,
|
||||||
|
const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx) {
|
||||||
const int col = blockDim.x*blockIdx.x + threadIdx.x;
|
const int col = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
const int half_n_dims = ncols/4;
|
const int half_n_dims = ncols/4;
|
||||||
|
|
||||||
@ -4098,8 +4099,9 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
|
|||||||
const int i = row*ncols + col;
|
const int i = row*ncols + col;
|
||||||
|
|
||||||
const float col_theta_scale = powf(theta_scale, col);
|
const float col_theta_scale = powf(theta_scale, col);
|
||||||
|
const float p = p0 + p_delta*(row/p_delta_rows);
|
||||||
|
|
||||||
const float theta = p*col_theta_scale;
|
const float theta = min(p, p_delta*(n_ctx - 2))*col_theta_scale;
|
||||||
const float sin_theta = sinf(theta);
|
const float sin_theta = sinf(theta);
|
||||||
const float cos_theta = cosf(theta);
|
const float cos_theta = cosf(theta);
|
||||||
|
|
||||||
@ -4109,7 +4111,7 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
|
|||||||
dst[i + 0] = x0*cos_theta - x1*sin_theta;
|
dst[i + 0] = x0*cos_theta - x1*sin_theta;
|
||||||
dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
|
dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
|
||||||
|
|
||||||
const float block_theta = block_p*col_theta_scale;
|
const float block_theta = max(p - p_delta*(n_ctx - 2), 0.f)*col_theta_scale;
|
||||||
const float sin_block_theta = sinf(block_theta);
|
const float sin_block_theta = sinf(block_theta);
|
||||||
const float cos_block_theta = cosf(block_theta);
|
const float cos_block_theta = cosf(block_theta);
|
||||||
|
|
||||||
@ -4984,12 +4986,13 @@ static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, co
|
|||||||
rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
|
rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
|
static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
|
||||||
GGML_ASSERT(nrows % 4 == 0);
|
const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) {
|
||||||
const dim3 block_dims(4*CUDA_ROPE_BLOCK_SIZE, 1, 1);
|
GGML_ASSERT(ncols % 4 == 0);
|
||||||
const int num_blocks_x = (ncols + 4*CUDA_ROPE_BLOCK_SIZE - 1) / (4*CUDA_ROPE_BLOCK_SIZE);
|
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
|
||||||
|
const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
|
||||||
const dim3 block_nums(num_blocks_x, nrows, 1);
|
const dim3 block_nums(num_blocks_x, nrows, 1);
|
||||||
rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, block_p, theta_scale);
|
rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale, n_ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
|
static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
|
||||||
@ -5723,22 +5726,18 @@ inline void ggml_cuda_op_rope(
|
|||||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||||
|
|
||||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||||
|
const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
|
||||||
|
|
||||||
const bool is_neox = mode & 2;
|
const bool is_neox = mode & 2;
|
||||||
const bool is_glm = mode & 4;
|
const bool is_glm = mode & 4;
|
||||||
|
|
||||||
// compute
|
// compute
|
||||||
if (is_glm) {
|
if (is_glm) {
|
||||||
const float p = (((mode & 1) == 0 ? n_past + i02 : i02)) * freq_scale;
|
rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, n_ctx, cudaStream_main);
|
||||||
const float id_p = min(p, n_ctx - 2.f);
|
|
||||||
const float block_p = max(p - (n_ctx - 2.f), 0.f);
|
|
||||||
rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
|
|
||||||
} else if (is_neox) {
|
} else if (is_neox) {
|
||||||
GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
|
GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
|
||||||
const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
|
|
||||||
rope_neox_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
|
rope_neox_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
|
||||||
} else {
|
} else {
|
||||||
const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
|
|
||||||
rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
|
rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -5746,6 +5745,7 @@ inline void ggml_cuda_op_rope(
|
|||||||
(void) dst;
|
(void) dst;
|
||||||
(void) src0_ddq_i;
|
(void) src0_ddq_i;
|
||||||
(void) src1_ddf_i;
|
(void) src1_ddf_i;
|
||||||
|
(void) i02;
|
||||||
(void) i1;
|
(void) i1;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -5781,6 +5781,7 @@ inline void ggml_cuda_op_alibi(
|
|||||||
(void) src1;
|
(void) src1;
|
||||||
(void) src0_ddq_i;
|
(void) src0_ddq_i;
|
||||||
(void) src1_ddf_i;
|
(void) src1_ddf_i;
|
||||||
|
(void) i02;
|
||||||
(void) i1;
|
(void) i1;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -6400,10 +6401,7 @@ void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ten
|
|||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
|
GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
|
||||||
|
|
||||||
const int mode = ((int32_t *) dst->op_params)[2];
|
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, true);
|
||||||
const bool is_glm = mode & 4;
|
|
||||||
|
|
||||||
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, !is_glm); // flatten support not implemented for glm
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
|
115
ggml-metal.m
115
ggml-metal.m
@ -63,7 +63,10 @@ struct ggml_metal_context {
|
|||||||
GGML_METAL_DECL_KERNEL(relu);
|
GGML_METAL_DECL_KERNEL(relu);
|
||||||
GGML_METAL_DECL_KERNEL(gelu);
|
GGML_METAL_DECL_KERNEL(gelu);
|
||||||
GGML_METAL_DECL_KERNEL(soft_max);
|
GGML_METAL_DECL_KERNEL(soft_max);
|
||||||
|
GGML_METAL_DECL_KERNEL(soft_max_4);
|
||||||
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
||||||
|
GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
|
||||||
|
GGML_METAL_DECL_KERNEL(get_rows_f32);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
||||||
@ -75,8 +78,10 @@ struct ggml_metal_context {
|
|||||||
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
|
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
|
||||||
GGML_METAL_DECL_KERNEL(rms_norm);
|
GGML_METAL_DECL_KERNEL(rms_norm);
|
||||||
GGML_METAL_DECL_KERNEL(norm);
|
GGML_METAL_DECL_KERNEL(norm);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
|
||||||
@ -85,6 +90,7 @@ struct ggml_metal_context {
|
|||||||
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
|
||||||
@ -117,14 +123,17 @@ static NSString * const msl_library_source = @"see metal.metal";
|
|||||||
struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||||
metal_printf("%s: allocating\n", __func__);
|
metal_printf("%s: allocating\n", __func__);
|
||||||
|
|
||||||
// Show all the Metal device instances in the system
|
|
||||||
NSArray * devices = MTLCopyAllDevices();
|
|
||||||
id <MTLDevice> device;
|
id <MTLDevice> device;
|
||||||
NSString * s;
|
NSString * s;
|
||||||
|
|
||||||
|
#if TARGET_OS_OSX
|
||||||
|
// Show all the Metal device instances in the system
|
||||||
|
NSArray * devices = MTLCopyAllDevices();
|
||||||
for (device in devices) {
|
for (device in devices) {
|
||||||
s = [device name];
|
s = [device name];
|
||||||
metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
|
metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
// Pick and show default Metal device
|
// Pick and show default Metal device
|
||||||
device = MTLCreateSystemDefaultDevice();
|
device = MTLCreateSystemDefaultDevice();
|
||||||
@ -139,14 +148,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
ctx->n_buffers = 0;
|
ctx->n_buffers = 0;
|
||||||
ctx->concur_list_len = 0;
|
ctx->concur_list_len = 0;
|
||||||
|
|
||||||
ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
|
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
|
||||||
|
|
||||||
#if 0
|
#ifdef GGML_SWIFT
|
||||||
// compile from source string and show compile log
|
// load the default.metallib file
|
||||||
{
|
{
|
||||||
NSError * error = nil;
|
NSError * error = nil;
|
||||||
|
|
||||||
ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
|
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
||||||
|
NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"];
|
||||||
|
NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath];
|
||||||
|
NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"];
|
||||||
|
NSURL * libURL = [NSURL fileURLWithPath:libPath];
|
||||||
|
|
||||||
|
// Load the metallib file into a Metal library
|
||||||
|
ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
|
||||||
|
|
||||||
if (error) {
|
if (error) {
|
||||||
metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||||
return NULL;
|
return NULL;
|
||||||
@ -207,7 +224,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(relu);
|
GGML_METAL_ADD_KERNEL(relu);
|
||||||
GGML_METAL_ADD_KERNEL(gelu);
|
GGML_METAL_ADD_KERNEL(gelu);
|
||||||
GGML_METAL_ADD_KERNEL(soft_max);
|
GGML_METAL_ADD_KERNEL(soft_max);
|
||||||
|
GGML_METAL_ADD_KERNEL(soft_max_4);
|
||||||
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
||||||
|
GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
|
||||||
|
GGML_METAL_ADD_KERNEL(get_rows_f32);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
|
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
|
||||||
@ -219,8 +239,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
|
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
|
||||||
GGML_METAL_ADD_KERNEL(rms_norm);
|
GGML_METAL_ADD_KERNEL(rms_norm);
|
||||||
GGML_METAL_ADD_KERNEL(norm);
|
GGML_METAL_ADD_KERNEL(norm);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
|
||||||
@ -229,6 +251,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
|
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
|
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
|
GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
|
||||||
@ -247,13 +270,15 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
#undef GGML_METAL_ADD_KERNEL
|
#undef GGML_METAL_ADD_KERNEL
|
||||||
}
|
}
|
||||||
|
|
||||||
metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
|
||||||
metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
||||||
|
#if TARGET_OS_OSX
|
||||||
|
metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
||||||
if (ctx->device.maxTransferRate != 0) {
|
if (ctx->device.maxTransferRate != 0) {
|
||||||
metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
|
metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
|
||||||
} else {
|
} else {
|
||||||
metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
|
metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
return ctx;
|
return ctx;
|
||||||
}
|
}
|
||||||
@ -273,7 +298,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||||||
GGML_METAL_DEL_KERNEL(relu);
|
GGML_METAL_DEL_KERNEL(relu);
|
||||||
GGML_METAL_DEL_KERNEL(gelu);
|
GGML_METAL_DEL_KERNEL(gelu);
|
||||||
GGML_METAL_DEL_KERNEL(soft_max);
|
GGML_METAL_DEL_KERNEL(soft_max);
|
||||||
|
GGML_METAL_DEL_KERNEL(soft_max_4);
|
||||||
GGML_METAL_DEL_KERNEL(diag_mask_inf);
|
GGML_METAL_DEL_KERNEL(diag_mask_inf);
|
||||||
|
GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
|
||||||
|
GGML_METAL_DEL_KERNEL(get_rows_f32);
|
||||||
GGML_METAL_DEL_KERNEL(get_rows_f16);
|
GGML_METAL_DEL_KERNEL(get_rows_f16);
|
||||||
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
|
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
|
||||||
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
|
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
|
||||||
@ -285,8 +313,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||||||
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
||||||
GGML_METAL_DEL_KERNEL(rms_norm);
|
GGML_METAL_DEL_KERNEL(rms_norm);
|
||||||
GGML_METAL_DEL_KERNEL(norm);
|
GGML_METAL_DEL_KERNEL(norm);
|
||||||
|
GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
|
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
|
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
|
||||||
|
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
|
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
|
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
|
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
|
||||||
@ -295,6 +325,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||||||
GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
|
GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
|
GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
|
GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
|
||||||
|
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
|
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
|
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
|
GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
|
||||||
@ -327,7 +358,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||||||
|
|
||||||
void * ggml_metal_host_malloc(size_t n) {
|
void * ggml_metal_host_malloc(size_t n) {
|
||||||
void * data = NULL;
|
void * data = NULL;
|
||||||
const int result = posix_memalign((void **) &data, getpagesize(), n);
|
const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
|
||||||
if (result != 0) {
|
if (result != 0) {
|
||||||
metal_printf("%s: error: posix_memalign failed\n", __func__);
|
metal_printf("%s: error: posix_memalign failed\n", __func__);
|
||||||
return NULL;
|
return NULL;
|
||||||
@ -365,6 +396,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
|
|||||||
for (int i = 0; i < ctx->n_buffers; ++i) {
|
for (int i = 0; i < ctx->n_buffers; ++i) {
|
||||||
const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
|
const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
|
||||||
|
|
||||||
|
//metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
|
||||||
if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
|
if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
|
||||||
*offs = (size_t) ioffs;
|
*offs = (size_t) ioffs;
|
||||||
|
|
||||||
@ -401,7 +433,7 @@ bool ggml_metal_add_buffer(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const size_t size_page = getpagesize();
|
const size_t size_page = sysconf(_SC_PAGESIZE);
|
||||||
|
|
||||||
size_t size_aligned = size;
|
size_t size_aligned = size;
|
||||||
if ((size_aligned % size_page) != 0) {
|
if ((size_aligned % size_page) != 0) {
|
||||||
@ -454,6 +486,7 @@ bool ggml_metal_add_buffer(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if TARGET_OS_OSX
|
||||||
metal_printf(", (%8.2f / %8.2f)",
|
metal_printf(", (%8.2f / %8.2f)",
|
||||||
ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
|
ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
|
||||||
ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
||||||
@ -463,6 +496,9 @@ bool ggml_metal_add_buffer(
|
|||||||
} else {
|
} else {
|
||||||
metal_printf("\n");
|
metal_printf("\n");
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
metal_printf(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
@ -698,6 +734,7 @@ void ggml_metal_graph_compute(
|
|||||||
case GGML_OP_ADD:
|
case GGML_OP_ADD:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||||
|
|
||||||
// utilize float4
|
// utilize float4
|
||||||
GGML_ASSERT(ne00 % 4 == 0);
|
GGML_ASSERT(ne00 % 4 == 0);
|
||||||
@ -705,6 +742,7 @@ void ggml_metal_graph_compute(
|
|||||||
|
|
||||||
if (ggml_nelements(src1) == ne10) {
|
if (ggml_nelements(src1) == ne10) {
|
||||||
// src1 is a row
|
// src1 is a row
|
||||||
|
GGML_ASSERT(ne11 == 1);
|
||||||
[encoder setComputePipelineState:ctx->pipeline_add_row];
|
[encoder setComputePipelineState:ctx->pipeline_add_row];
|
||||||
} else {
|
} else {
|
||||||
[encoder setComputePipelineState:ctx->pipeline_add];
|
[encoder setComputePipelineState:ctx->pipeline_add];
|
||||||
@ -721,6 +759,7 @@ void ggml_metal_graph_compute(
|
|||||||
case GGML_OP_MUL:
|
case GGML_OP_MUL:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||||
|
|
||||||
// utilize float4
|
// utilize float4
|
||||||
GGML_ASSERT(ne00 % 4 == 0);
|
GGML_ASSERT(ne00 % 4 == 0);
|
||||||
@ -728,6 +767,7 @@ void ggml_metal_graph_compute(
|
|||||||
|
|
||||||
if (ggml_nelements(src1) == ne10) {
|
if (ggml_nelements(src1) == ne10) {
|
||||||
// src1 is a row
|
// src1 is a row
|
||||||
|
GGML_ASSERT(ne11 == 1);
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_row];
|
[encoder setComputePipelineState:ctx->pipeline_mul_row];
|
||||||
} else {
|
} else {
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul];
|
[encoder setComputePipelineState:ctx->pipeline_mul];
|
||||||
@ -743,6 +783,8 @@ void ggml_metal_graph_compute(
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_SCALE:
|
case GGML_OP_SCALE:
|
||||||
{
|
{
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
|
|
||||||
const float scale = *(const float *) src1->data;
|
const float scale = *(const float *) src1->data;
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_scale];
|
[encoder setComputePipelineState:ctx->pipeline_scale];
|
||||||
@ -750,7 +792,7 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
||||||
|
|
||||||
const int64_t n = ggml_nelements(dst);
|
const int64_t n = ggml_nelements(dst)/4;
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
@ -762,7 +804,7 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
|
||||||
const int64_t n = ggml_nelements(dst);
|
const int64_t n = ggml_nelements(dst)/4;
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
@ -782,7 +824,7 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
|
||||||
const int64_t n = ggml_nelements(dst);
|
const int64_t n = ggml_nelements(dst)/4;
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
@ -796,13 +838,16 @@ void ggml_metal_graph_compute(
|
|||||||
{
|
{
|
||||||
const int nth = 32;
|
const int nth = 32;
|
||||||
|
|
||||||
|
if (ne00%4 == 0) {
|
||||||
|
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
|
||||||
|
} else {
|
||||||
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
||||||
|
}
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
||||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
||||||
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
|
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
@ -810,14 +855,23 @@ void ggml_metal_graph_compute(
|
|||||||
{
|
{
|
||||||
const int n_past = ((int32_t *)(dst->op_params))[0];
|
const int n_past = ((int32_t *)(dst->op_params))[0];
|
||||||
|
|
||||||
|
if (ne00%8 == 0) {
|
||||||
|
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
|
||||||
|
} else {
|
||||||
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
|
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
|
||||||
|
}
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
||||||
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
|
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
|
||||||
|
|
||||||
|
if (ne00%8 == 0) {
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
|
}
|
||||||
|
else {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
|
}
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
{
|
{
|
||||||
@ -830,13 +884,14 @@ void ggml_metal_graph_compute(
|
|||||||
|
|
||||||
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
||||||
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
||||||
if (ggml_is_contiguous(src0) &&
|
if (!ggml_is_transposed(src0) &&
|
||||||
ggml_is_contiguous(src1) &&
|
!ggml_is_transposed(src1) &&
|
||||||
src1t == GGML_TYPE_F32 &&
|
src1t == GGML_TYPE_F32 &&
|
||||||
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
||||||
ne00%32 == 0 &&
|
ne00%32 == 0 &&
|
||||||
ne11 > 1) {
|
ne11 > 1) {
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
|
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
|
||||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
|
||||||
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
|
||||||
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
|
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
|
||||||
@ -856,25 +911,38 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
||||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
|
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
|
||||||
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
|
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
|
||||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
|
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
|
||||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
|
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
|
||||||
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
|
||||||
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
|
||||||
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
|
||||||
|
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
|
||||||
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||||
} else {
|
} else {
|
||||||
int nth0 = 32;
|
int nth0 = 32;
|
||||||
int nth1 = 1;
|
int nth1 = 1;
|
||||||
|
int nrows = 1;
|
||||||
|
|
||||||
// use custom matrix x vector kernel
|
// use custom matrix x vector kernel
|
||||||
switch (src0t) {
|
switch (src0t) {
|
||||||
|
case GGML_TYPE_F32:
|
||||||
|
{
|
||||||
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
|
||||||
|
nrows = 4;
|
||||||
|
} break;
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
{
|
{
|
||||||
nth0 = 32;
|
nth0 = 32;
|
||||||
nth1 = 1;
|
nth1 = 1;
|
||||||
if (ne11 * ne12 < 4) {
|
if (ne11 * ne12 < 4) {
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
|
||||||
|
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
||||||
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
|
||||||
|
nrows = ne11;
|
||||||
} else {
|
} else {
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
||||||
|
nrows = 4;
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
@ -995,7 +1063,7 @@ void ggml_metal_graph_compute(
|
|||||||
else if (src0t == GGML_TYPE_Q6_K) {
|
else if (src0t == GGML_TYPE_Q6_K) {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
} else {
|
} else {
|
||||||
int64_t ny = (ne11 + 3)/4;
|
int64_t ny = (ne11 + nrows - 1)/nrows;
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1003,6 +1071,7 @@ void ggml_metal_graph_compute(
|
|||||||
case GGML_OP_GET_ROWS:
|
case GGML_OP_GET_ROWS:
|
||||||
{
|
{
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
|
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break;
|
||||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
||||||
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
||||||
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
|
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
|
||||||
@ -1018,9 +1087,9 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
[encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
||||||
[encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
|
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
||||||
[encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5];
|
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
|
||||||
|
|
||||||
const int64_t n = ggml_nelements(src1);
|
const int64_t n = ggml_nelements(src1);
|
||||||
|
|
||||||
@ -1141,7 +1210,7 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
|
[encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
|
||||||
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
|
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_DUP:
|
case GGML_OP_DUP:
|
||||||
case GGML_OP_CPY:
|
case GGML_OP_CPY:
|
||||||
|
568
ggml-metal.metal
568
ggml-metal.metal
@ -63,18 +63,18 @@ kernel void kernel_mul_row(
|
|||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_scale(
|
kernel void kernel_scale(
|
||||||
device const float * src0,
|
device const float4 * src0,
|
||||||
device float * dst,
|
device float4 * dst,
|
||||||
constant float & scale,
|
constant float & scale,
|
||||||
uint tpig[[thread_position_in_grid]]) {
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
dst[tpig] = src0[tpig] * scale;
|
dst[tpig] = src0[tpig] * scale;
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_silu(
|
kernel void kernel_silu(
|
||||||
device const float * src0,
|
device const float4 * src0,
|
||||||
device float * dst,
|
device float4 * dst,
|
||||||
uint tpig[[thread_position_in_grid]]) {
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
float x = src0[tpig];
|
device const float4 & x = src0[tpig];
|
||||||
dst[tpig] = x / (1.0f + exp(-x));
|
dst[tpig] = x / (1.0f + exp(-x));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -89,10 +89,10 @@ constant float GELU_COEF_A = 0.044715f;
|
|||||||
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||||
|
|
||||||
kernel void kernel_gelu(
|
kernel void kernel_gelu(
|
||||||
device const float * src0,
|
device const float4 * src0,
|
||||||
device float * dst,
|
device float4 * dst,
|
||||||
uint tpig[[thread_position_in_grid]]) {
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
float x = src0[tpig];
|
device const float4 & x = src0[tpig];
|
||||||
|
|
||||||
// BEWARE !!!
|
// BEWARE !!!
|
||||||
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
|
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
|
||||||
@ -107,7 +107,6 @@ kernel void kernel_soft_max(
|
|||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
threadgroup float * buf [[threadgroup(0)]],
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
uint3 ntg[[threads_per_threadgroup]]) {
|
uint3 ntg[[threads_per_threadgroup]]) {
|
||||||
@ -119,64 +118,70 @@ kernel void kernel_soft_max(
|
|||||||
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||||
|
|
||||||
// parallel max
|
// parallel max
|
||||||
buf[tpitg[0]] = -INFINITY;
|
float lmax = psrc0[tpitg[0]];
|
||||||
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
|
||||||
buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
|
lmax = MAX(lmax, psrc0[i00]);
|
||||||
}
|
}
|
||||||
|
const float max = simd_max(lmax);
|
||||||
// reduce
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
for (uint i = ntg[0]/2; i > 0; i /= 2) {
|
|
||||||
if (tpitg[0] < i) {
|
|
||||||
buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
|
|
||||||
}
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
}
|
|
||||||
|
|
||||||
//// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
|
|
||||||
// the loop, and when that is done, buf[0] has the correct (synchronized) value
|
|
||||||
//if (tpitg[0] == 0) {
|
|
||||||
// buf[0] = buf[0];
|
|
||||||
//}
|
|
||||||
|
|
||||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
const float max = buf[0];
|
|
||||||
|
|
||||||
// parallel sum
|
// parallel sum
|
||||||
buf[tpitg[0]] = 0.0f;
|
float lsum = 0.0f;
|
||||||
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
||||||
const float exp_psrc0 = exp(psrc0[i00] - max);
|
const float exp_psrc0 = exp(psrc0[i00] - max);
|
||||||
buf[tpitg[0]] += exp_psrc0;
|
lsum += exp_psrc0;
|
||||||
// Remember the result of exp here. exp is expensive, so we really do not
|
// Remember the result of exp here. exp is expensive, so we really do not
|
||||||
// whish to compute it twice.
|
// whish to compute it twice.
|
||||||
pdst[i00] = exp_psrc0;
|
pdst[i00] = exp_psrc0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// reduce
|
const float sum = simd_sum(lsum);
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
for (uint i = ntg[0]/2; i > 0; i /= 2) {
|
|
||||||
if (tpitg[0] < i) {
|
|
||||||
buf[tpitg[0]] += buf[tpitg[0] + i];
|
|
||||||
}
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
}
|
|
||||||
|
|
||||||
// broadcast - not needed, see above
|
|
||||||
//// broadcast
|
|
||||||
//if (tpitg[0] == 0) {
|
|
||||||
// buf[0] = buf[0];
|
|
||||||
//}
|
|
||||||
|
|
||||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
const float sum = buf[0];
|
|
||||||
|
|
||||||
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
||||||
pdst[i00] /= sum;
|
pdst[i00] /= sum;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_soft_max_4(
|
||||||
|
device const float * src0,
|
||||||
|
device float * dst,
|
||||||
|
constant int64_t & ne00,
|
||||||
|
constant int64_t & ne01,
|
||||||
|
constant int64_t & ne02,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
|
uint3 ntg[[threads_per_threadgroup]]) {
|
||||||
|
const int64_t i03 = tgpig[2];
|
||||||
|
const int64_t i02 = tgpig[1];
|
||||||
|
const int64_t i01 = tgpig[0];
|
||||||
|
|
||||||
|
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||||
|
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||||
|
|
||||||
|
// parallel max
|
||||||
|
float4 lmax4 = psrc4[tpitg[0]];
|
||||||
|
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
|
||||||
|
lmax4 = fmax(lmax4, psrc4[i00]);
|
||||||
|
}
|
||||||
|
float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
||||||
|
|
||||||
|
const float max = simd_max(lmax);
|
||||||
|
|
||||||
|
// parallel sum
|
||||||
|
float4 lsum4 = 0.0f;
|
||||||
|
for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
|
||||||
|
const float4 exp_psrc4 = exp(psrc4[i00] - max);
|
||||||
|
lsum4 += exp_psrc4;
|
||||||
|
pdst4[i00] = exp_psrc4;
|
||||||
|
}
|
||||||
|
float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
||||||
|
|
||||||
|
const float sum = simd_sum(lsum);
|
||||||
|
|
||||||
|
for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
|
||||||
|
pdst4[i00] /= sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_diag_mask_inf(
|
kernel void kernel_diag_mask_inf(
|
||||||
device const float * src0,
|
device const float * src0,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
@ -195,6 +200,33 @@ kernel void kernel_diag_mask_inf(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_diag_mask_inf_8(
|
||||||
|
device const float4 * src0,
|
||||||
|
device float4 * dst,
|
||||||
|
constant int64_t & ne00,
|
||||||
|
constant int64_t & ne01,
|
||||||
|
constant int & n_past,
|
||||||
|
uint3 tpig[[thread_position_in_grid]]) {
|
||||||
|
|
||||||
|
const int64_t i = 2*tpig[0];
|
||||||
|
|
||||||
|
dst[i+0] = src0[i+0];
|
||||||
|
dst[i+1] = src0[i+1];
|
||||||
|
int64_t i4 = 4*i;
|
||||||
|
const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
|
||||||
|
const int64_t i01 = i4/(ne00); i4 -= i01*ne00;
|
||||||
|
const int64_t i00 = i4;
|
||||||
|
for (int k = 3; k >= 0; --k) {
|
||||||
|
if (i00 + 4 + k <= n_past + i01) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
dst[i+1][k] = -INFINITY;
|
||||||
|
if (i00 + k > n_past + i01) {
|
||||||
|
dst[i][k] = -INFINITY;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_norm(
|
kernel void kernel_norm(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
@ -220,14 +252,10 @@ kernel void kernel_norm(
|
|||||||
}
|
}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
}
|
}
|
||||||
//// broadcast
|
const float mean = sum[0] / ne00;
|
||||||
//if (tpitg == 0) {
|
|
||||||
// sum[0] /= ne00;
|
|
||||||
//}
|
|
||||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
const float mean = sum[0];
|
|
||||||
|
|
||||||
// recenter and VARIANCE
|
// recenter and VARIANCE
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
device float * y = dst + tgpig*ne00;
|
device float * y = dst + tgpig*ne00;
|
||||||
sum[tpitg] = 0.0f;
|
sum[tpitg] = 0.0f;
|
||||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||||
@ -235,12 +263,6 @@ kernel void kernel_norm(
|
|||||||
sum[tpitg] += y[i00] * y[i00];
|
sum[tpitg] += y[i00] * y[i00];
|
||||||
}
|
}
|
||||||
|
|
||||||
//// VARIANCE
|
|
||||||
//// parallel sum
|
|
||||||
//sum[tpitg] = 0.0f;
|
|
||||||
//for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
||||||
// sum[tpitg] += y[i00] * y[i00];
|
|
||||||
//}
|
|
||||||
// reduce
|
// reduce
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
for (uint i = ntg/2; i > 0; i /= 2) {
|
for (uint i = ntg/2; i > 0; i /= 2) {
|
||||||
@ -249,12 +271,7 @@ kernel void kernel_norm(
|
|||||||
}
|
}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
}
|
}
|
||||||
//// broadcast
|
const float variance = sum[0] / ne00;
|
||||||
//if (tpitg == 0) {
|
|
||||||
// sum[0] /= ne00;
|
|
||||||
//}
|
|
||||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
const float variance = sum[0];
|
|
||||||
|
|
||||||
const float scale = 1.0f/sqrt(variance + eps);
|
const float scale = 1.0f/sqrt(variance + eps);
|
||||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||||
@ -262,7 +279,6 @@ kernel void kernel_norm(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
kernel void kernel_rms_norm(
|
kernel void kernel_rms_norm(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
@ -507,6 +523,79 @@ kernel void kernel_mul_mat_q8_0_f32(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define N_F32_F32 4
|
||||||
|
|
||||||
|
kernel void kernel_mul_mat_f32_f32(
|
||||||
|
device const char * src0,
|
||||||
|
device const char * src1,
|
||||||
|
device float * dst,
|
||||||
|
constant int64_t & ne00,
|
||||||
|
constant int64_t & ne01,
|
||||||
|
constant int64_t & ne02,
|
||||||
|
constant uint64_t & nb00,
|
||||||
|
constant uint64_t & nb01,
|
||||||
|
constant uint64_t & nb02,
|
||||||
|
constant int64_t & ne10,
|
||||||
|
constant int64_t & ne11,
|
||||||
|
constant int64_t & ne12,
|
||||||
|
constant uint64_t & nb10,
|
||||||
|
constant uint64_t & nb11,
|
||||||
|
constant uint64_t & nb12,
|
||||||
|
constant int64_t & ne0,
|
||||||
|
constant int64_t & ne1,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
|
const int64_t r0 = tgpig.x;
|
||||||
|
const int64_t rb = tgpig.y*N_F32_F32;
|
||||||
|
const int64_t im = tgpig.z;
|
||||||
|
|
||||||
|
device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
|
||||||
|
|
||||||
|
if (ne00 < 128) {
|
||||||
|
for (int row = 0; row < N_F32_F32; ++row) {
|
||||||
|
int r1 = rb + row;
|
||||||
|
if (r1 >= ne11) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
||||||
|
|
||||||
|
float sumf = 0;
|
||||||
|
for (int i = tiisg; i < ne00; i += 32) {
|
||||||
|
sumf += (float) x[i] * (float) y[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
float all_sum = simd_sum(sumf);
|
||||||
|
if (tiisg == 0) {
|
||||||
|
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
device const float4 * x4 = (device const float4 *)x;
|
||||||
|
for (int row = 0; row < N_F32_F32; ++row) {
|
||||||
|
int r1 = rb + row;
|
||||||
|
if (r1 >= ne11) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
||||||
|
device const float4 * y4 = (device const float4 *) y;
|
||||||
|
|
||||||
|
float sumf = 0;
|
||||||
|
for (int i = tiisg; i < ne00/4; i += 32) {
|
||||||
|
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
||||||
|
}
|
||||||
|
|
||||||
|
float all_sum = simd_sum(sumf);
|
||||||
|
if (tiisg == 0) {
|
||||||
|
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
|
||||||
|
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_mul_mat_f16_f32_1row(
|
kernel void kernel_mul_mat_f16_f32_1row(
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
@ -630,7 +719,49 @@ kernel void kernel_mul_mat_f16_f32(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assumes row size (ne00) is a multiple of 4
|
||||||
|
kernel void kernel_mul_mat_f16_f32_l4(
|
||||||
|
device const char * src0,
|
||||||
|
device const char * src1,
|
||||||
|
device float * dst,
|
||||||
|
constant int64_t & ne00,
|
||||||
|
constant int64_t & ne01,
|
||||||
|
constant int64_t & ne02,
|
||||||
|
constant uint64_t & nb00,
|
||||||
|
constant uint64_t & nb01,
|
||||||
|
constant uint64_t & nb02,
|
||||||
|
constant int64_t & ne10,
|
||||||
|
constant int64_t & ne11,
|
||||||
|
constant int64_t & ne12,
|
||||||
|
constant uint64_t & nb10,
|
||||||
|
constant uint64_t & nb11,
|
||||||
|
constant uint64_t & nb12,
|
||||||
|
constant int64_t & ne0,
|
||||||
|
constant int64_t & ne1,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
|
const int nrows = ne11;
|
||||||
|
const int64_t r0 = tgpig.x;
|
||||||
|
const int64_t im = tgpig.z;
|
||||||
|
|
||||||
|
device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
|
||||||
|
|
||||||
|
for (int r1 = 0; r1 < nrows; ++r1) {
|
||||||
|
device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
|
||||||
|
|
||||||
|
float sumf = 0;
|
||||||
|
for (int i = tiisg; i < ne00/4; i += 32) {
|
||||||
|
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
||||||
|
}
|
||||||
|
|
||||||
|
float all_sum = simd_sum(sumf);
|
||||||
|
if (tiisg == 0) {
|
||||||
|
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_alibi_f32(
|
kernel void kernel_alibi_f32(
|
||||||
@ -699,25 +830,27 @@ kernel void kernel_rope(
|
|||||||
constant int & mode,
|
constant int & mode,
|
||||||
constant float & freq_base,
|
constant float & freq_base,
|
||||||
constant float & freq_scale,
|
constant float & freq_scale,
|
||||||
uint3 tpig[[thread_position_in_grid]]) {
|
uint tiitg[[thread_index_in_threadgroup]],
|
||||||
const int64_t i3 = tpig[2];
|
uint3 tptg[[threads_per_threadgroup]],
|
||||||
const int64_t i2 = tpig[1];
|
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||||
const int64_t i1 = tpig[0];
|
const int64_t i3 = tgpig[2];
|
||||||
|
const int64_t i2 = tgpig[1];
|
||||||
|
const int64_t i1 = tgpig[0];
|
||||||
|
|
||||||
const bool is_neox = mode & 2;
|
const bool is_neox = mode & 2;
|
||||||
const float theta_scale = pow(freq_base, -2.0f/n_dims);
|
|
||||||
|
|
||||||
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
||||||
|
|
||||||
float theta = freq_scale * (float)p;
|
const float theta_0 = freq_scale * (float)p;
|
||||||
|
const float inv_ndims = -1.f/n_dims;
|
||||||
|
|
||||||
if (!is_neox) {
|
if (!is_neox) {
|
||||||
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
||||||
|
|
||||||
|
const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
|
||||||
const float cos_theta = cos(theta);
|
const float cos_theta = cos(theta);
|
||||||
const float sin_theta = sin(theta);
|
const float sin_theta = sin(theta);
|
||||||
|
|
||||||
theta *= theta_scale;
|
|
||||||
|
|
||||||
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||||
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||||
|
|
||||||
@ -729,12 +862,12 @@ kernel void kernel_rope(
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
||||||
for (int64_t ic = 0; ic < n_dims; ic += 2) {
|
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
|
||||||
|
|
||||||
|
const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
|
||||||
const float cos_theta = cos(theta);
|
const float cos_theta = cos(theta);
|
||||||
const float sin_theta = sin(theta);
|
const float sin_theta = sin(theta);
|
||||||
|
|
||||||
theta *= theta_scale;
|
|
||||||
|
|
||||||
const int64_t i0 = ib*n_dims + ic/2;
|
const int64_t i0 = ib*n_dims + ic/2;
|
||||||
|
|
||||||
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||||
@ -1138,31 +1271,40 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|||||||
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
|
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
|
||||||
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
|
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
|
||||||
|
|
||||||
float yl[16];
|
float yl[32];
|
||||||
|
|
||||||
const uint16_t kmask1 = 0x0303;
|
const uint16_t kmask1 = 0x3030;
|
||||||
const uint16_t kmask2 = 0x0f0f;
|
const uint16_t kmask2 = 0x0f0f;
|
||||||
|
|
||||||
const int tid = tiisg/2;
|
const int tid = tiisg/4;
|
||||||
const int ix = tiisg%2;
|
const int ix = tiisg%4;
|
||||||
const int ip = tid/8; // 0 or 1
|
const int ip = tid/4; // 0 or 1
|
||||||
const int il = tid/2 - 4*ip; // 0...3
|
const int il = 2*((tid%4)/2); // 0 or 2
|
||||||
const int ir = tid%2;
|
const int ir = tid%2;
|
||||||
const int n = 8;
|
const int n = 8;
|
||||||
const int l0 = n*ir;
|
const int l0 = n*ir;
|
||||||
|
|
||||||
const uint16_t m1 = 1 << (4*ip + il);
|
// One would think that the Metal compiler would figure out that ip and il can only have
|
||||||
const uint16_t m2 = m1 << 8;
|
// 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
|
||||||
|
// with these two tales.
|
||||||
|
//
|
||||||
|
// Possible masks for the high bit
|
||||||
|
const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0
|
||||||
|
{0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2
|
||||||
|
{0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0
|
||||||
|
{0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
|
||||||
|
|
||||||
|
// Possible masks for the low 2 bits
|
||||||
|
const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
|
||||||
|
|
||||||
|
const ushort4 hm = mm[2*ip + il/2];
|
||||||
|
|
||||||
const int shift = 2*il;
|
const int shift = 2*il;
|
||||||
const uint16_t qm1 = 0x0003 << shift;
|
const float v1 = il == 0 ? 4.f : 64.f;
|
||||||
const uint16_t qm2 = 0x0300 << shift;
|
const float v2 = 4.f * v1;
|
||||||
const int32_t v1 = 4 << shift;
|
|
||||||
const int32_t v2 = 1024 << shift;
|
|
||||||
|
|
||||||
const uint16_t s_shift1 = 4*ip;
|
const uint16_t s_shift1 = 4*ip;
|
||||||
const uint16_t s_shift2 = s_shift1 + 2*(il/2);
|
const uint16_t s_shift2 = s_shift1 + il;
|
||||||
const int ik = 4 + (il%2);
|
|
||||||
|
|
||||||
const int q_offset = 32*ip + l0;
|
const int q_offset = 32*ip + l0;
|
||||||
const int y_offset = 128*ip + 32*il + l0;
|
const int y_offset = 128*ip + 32*il + l0;
|
||||||
@ -1171,12 +1313,19 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|||||||
|
|
||||||
device const float * y1 = yy + ix*QK_K + y_offset;
|
device const float * y1 = yy + ix*QK_K + y_offset;
|
||||||
|
|
||||||
float sumf1[2] = {0.f}, sumf2[2] = {0.f};
|
uint32_t scales32, aux32;
|
||||||
for (int i = ix; i < nb; i += 2) {
|
thread uint16_t * scales16 = (thread uint16_t *)&scales32;
|
||||||
|
thread const int8_t * scales = (thread const int8_t *)&scales32;
|
||||||
|
|
||||||
|
float sumf1[2] = {0.f};
|
||||||
|
float sumf2[2] = {0.f};
|
||||||
|
for (int i = ix; i < nb; i += 4) {
|
||||||
|
|
||||||
for (int l = 0; l < 8; ++l) {
|
for (int l = 0; l < 8; ++l) {
|
||||||
yl[l+ 0] = y1[l+ 0];
|
yl[l+ 0] = y1[l+ 0];
|
||||||
yl[l+ 8] = y1[l+16];
|
yl[l+ 8] = y1[l+16];
|
||||||
|
yl[l+16] = y1[l+32];
|
||||||
|
yl[l+24] = y1[l+48];
|
||||||
}
|
}
|
||||||
|
|
||||||
device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
|
device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
|
||||||
@ -1187,27 +1336,43 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|||||||
for (int row = 0; row < 2; ++row) {
|
for (int row = 0; row < 2; ++row) {
|
||||||
|
|
||||||
const float d_all = (float)dh[0];
|
const float d_all = (float)dh[0];
|
||||||
const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
|
|
||||||
|
|
||||||
float s1 = 0, s2 = 0;
|
scales16[0] = a[4];
|
||||||
for (int l = 0; l < n; l += 2) {
|
scales16[1] = a[5];
|
||||||
const uint16_t qs = q[l/2];
|
aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
|
||||||
s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
|
scales16[0] = a[il+0];
|
||||||
s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
|
scales16[1] = a[il+1];
|
||||||
}
|
scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
|
||||||
float d = d_all * (s1 + 1.f/256.f * s2);
|
|
||||||
sumf1[row] += d * scales[0];
|
|
||||||
sumf2[row] += d;
|
|
||||||
|
|
||||||
s1 = s2 = 0;
|
float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
|
||||||
for (int l = 0; l < n; l += 2) {
|
for (int l = 0; l < n; l += 2) {
|
||||||
const uint16_t qs = q[l/2+8];
|
const int32_t qs = q[l/2];
|
||||||
s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
|
s1 += yl[l+0] * (qs & qm[il/2][0]);
|
||||||
s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
|
s2 += yl[l+1] * (qs & qm[il/2][1]);
|
||||||
|
s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
|
||||||
|
s4 += yl[l+16] * (qs & qm[il/2][2]);
|
||||||
|
s5 += yl[l+17] * (qs & qm[il/2][3]);
|
||||||
|
s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
|
||||||
}
|
}
|
||||||
d = d_all * (s1 + 1.f/256.f * s2);
|
float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
|
||||||
sumf1[row] += d * scales[1];
|
float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
|
||||||
sumf2[row] += d;
|
sumf1[row] += d1 * (scales[0] - 32);
|
||||||
|
sumf2[row] += d2 * (scales[2] - 32);
|
||||||
|
|
||||||
|
s1 = s2 = s3 = s4 = s5 = s6 = 0;
|
||||||
|
for (int l = 0; l < n; l += 2) {
|
||||||
|
const int32_t qs = q[l/2+8];
|
||||||
|
s1 += yl[l+8] * (qs & qm[il/2][0]);
|
||||||
|
s2 += yl[l+9] * (qs & qm[il/2][1]);
|
||||||
|
s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
|
||||||
|
s4 += yl[l+24] * (qs & qm[il/2][2]);
|
||||||
|
s5 += yl[l+25] * (qs & qm[il/2][3]);
|
||||||
|
s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
|
||||||
|
}
|
||||||
|
d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
|
||||||
|
d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
|
||||||
|
sumf1[row] += d1 * (scales[1] - 32);
|
||||||
|
sumf2[row] += d2 * (scales[3] - 32);
|
||||||
|
|
||||||
q += step;
|
q += step;
|
||||||
h += step;
|
h += step;
|
||||||
@ -1216,15 +1381,17 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
y1 += 2 * QK_K;
|
y1 += 4 * QK_K;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int row = 0; row < 2; ++row) {
|
for (int row = 0; row < 2; ++row) {
|
||||||
const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
|
const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
|
||||||
const float tot = simd_sum(sumf);
|
sumf1[row] = simd_sum(sumf);
|
||||||
|
}
|
||||||
if (tiisg == 0) {
|
if (tiisg == 0) {
|
||||||
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
|
for (int row = 0; row < 2; ++row) {
|
||||||
|
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1579,17 +1746,25 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|||||||
sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
|
sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
|
||||||
sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
|
sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
|
||||||
|
|
||||||
float4 acc = {0.f, 0.f, 0.f, 0.f};
|
float4 acc1 = {0.f};
|
||||||
|
float4 acc2 = {0.f};
|
||||||
for (int l = 0; l < n; ++l) {
|
for (int l = 0; l < n; ++l) {
|
||||||
uint8_t h = qh[l];
|
uint8_t h = qh[l];
|
||||||
acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
|
acc1[0] += yl[l+0] * (q1[l] & 0x0F);
|
||||||
acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
|
acc1[1] += yl[l+8] * (q1[l] & 0xF0);
|
||||||
acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
|
acc1[2] += yh[l+0] * (q2[l] & 0x0F);
|
||||||
acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
|
acc1[3] += yh[l+8] * (q2[l] & 0xF0);
|
||||||
|
acc2[0] += h & hm1 ? yl[l+0] : 0.f;
|
||||||
|
acc2[1] += h & hm2 ? yl[l+8] : 0.f;
|
||||||
|
acc2[2] += h & hm3 ? yh[l+0] : 0.f;
|
||||||
|
acc2[3] += h & hm4 ? yh[l+8] : 0.f;
|
||||||
}
|
}
|
||||||
const float dall = dh[0];
|
const float dall = dh[0];
|
||||||
const float dmin = dh[1];
|
const float dmin = dh[1];
|
||||||
sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
|
sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
|
||||||
|
sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
|
||||||
|
sc8[4] * (acc1[2] + 16.f*acc2[2]) +
|
||||||
|
sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
|
||||||
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
||||||
|
|
||||||
q1 += step;
|
q1 += step;
|
||||||
@ -1762,6 +1937,15 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|||||||
|
|
||||||
//============================= templates and their specializations =============================
|
//============================= templates and their specializations =============================
|
||||||
|
|
||||||
|
// NOTE: this is not dequantizing - we are simply fitting the template
|
||||||
|
template <typename type4x4>
|
||||||
|
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
||||||
|
float4x4 temp = *(((device float4x4 *)src));
|
||||||
|
for (int i = 0; i < 16; i++){
|
||||||
|
reg[i/4][i%4] = temp[i/4][i%4];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
|
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
|
||||||
half4x4 temp = *(((device half4x4 *)src));
|
half4x4 temp = *(((device half4x4 *)src));
|
||||||
@ -1773,28 +1957,30 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
|
|||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
||||||
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
||||||
const half d = il ? (xb->d / 16.h) : xb->d;
|
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
||||||
const half m = il ? ( -8.h * 16.h) : -8.h;
|
const float d2 = d1 / 256.f;
|
||||||
|
const float md = -8.h * xb->d;
|
||||||
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
||||||
const ushort mask1 = il ? 0xF000 : 0x0F00;
|
const ushort mask1 = mask0 << 8;
|
||||||
|
|
||||||
for (int i=0;i<8;i++) {
|
for (int i=0;i<8;i++) {
|
||||||
reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
|
reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
|
||||||
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
|
reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
|
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
|
||||||
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
||||||
const half d = il ? (xb->d / 16.h) : xb->d;
|
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
||||||
const half m = xb->m;
|
const float d2 = d1 / 256.f;
|
||||||
|
const float m = xb->m;
|
||||||
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
||||||
const ushort mask1 = il ? 0xF000 : 0x0F00;
|
const ushort mask1 = mask0 << 8;
|
||||||
|
|
||||||
for (int i=0;i<8;i++) {
|
for (int i=0;i<8;i++) {
|
||||||
reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m;
|
reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
|
||||||
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
|
reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1830,7 +2016,7 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
|
|||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
|
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
|
||||||
const float d_all = (float)(xb->d);
|
const half d_all = xb->d;
|
||||||
device const uint8_t * q = (device const uint8_t *)xb->qs;
|
device const uint8_t * q = (device const uint8_t *)xb->qs;
|
||||||
device const uint8_t * h = (device const uint8_t *)xb->hmask;
|
device const uint8_t * h = (device const uint8_t *)xb->hmask;
|
||||||
device const int8_t * scales = (device const int8_t *)xb->scales;
|
device const int8_t * scales = (device const int8_t *)xb->scales;
|
||||||
@ -1843,16 +2029,18 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
|
|||||||
((il/4)>0 ? 12 : 3);
|
((il/4)>0 ? 12 : 3);
|
||||||
uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
|
uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
|
||||||
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
|
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
|
||||||
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
|
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
|
||||||
(scale_2&kmask2) | ((scale_1&kmask1) << 4);
|
: (scale_2&kmask2) | ((scale_1&kmask1) << 4);
|
||||||
float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
|
half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
|
||||||
|
const half ml = 4.h * dl;
|
||||||
|
|
||||||
il = (il/2)%4;
|
il = (il/2) & 3;
|
||||||
float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
||||||
uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
||||||
|
dl *= coef;
|
||||||
|
|
||||||
for (int i = 0; i < 16; ++i) {
|
for (int i = 0; i < 16; ++i) {
|
||||||
reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
|
reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
float kcoef = il&1 ? 1.f/16.f : 1.f;
|
float kcoef = il&1 ? 1.f/16.f : 1.f;
|
||||||
@ -1867,19 +2055,24 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
|
||||||
|
return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
|
||||||
|
: uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
|
||||||
|
}
|
||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
|
void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
|
||||||
device const uint8_t * q = xb->qs;
|
device const uchar * q = xb->qs;
|
||||||
|
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
const float d = (float)(xb->d);
|
|
||||||
const float min = (float)(xb->dmin);
|
|
||||||
short is = (il/4) * 2;
|
short is = (il/4) * 2;
|
||||||
q = q + (il/4) * 32 + 16 * (il&1);
|
q = q + (il/4) * 32 + 16 * (il&1);
|
||||||
il = il%4;
|
il = il & 3;
|
||||||
const uchar4 sc = get_scale_min_k4(is, xb->scales);
|
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
||||||
const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
|
const half d = il < 2 ? xb->d : xb->d / 16.h;
|
||||||
const float ml = il<2 ? min * sc[1] : min * sc[3];
|
const half min = xb->dmin;
|
||||||
|
const half dl = d * sc[0];
|
||||||
|
const half ml = min * sc[1];
|
||||||
#else
|
#else
|
||||||
q = q + 16 * (il&1);
|
q = q + 16 * (il&1);
|
||||||
device const uint8_t * s = xb->scales;
|
device const uint8_t * s = xb->scales;
|
||||||
@ -1900,19 +2093,19 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|||||||
device const uint8_t * qh = xb->qh;
|
device const uint8_t * qh = xb->qh;
|
||||||
|
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
const float d = (float)(xb->d);
|
|
||||||
const float min = (float)(xb->dmin);
|
|
||||||
short is = (il/4) * 2;
|
short is = (il/4) * 2;
|
||||||
q = q + 32 * (il/4) + 16 * (il&1);
|
q = q + 32 * (il/4) + 16 * (il&1);
|
||||||
qh = qh + 16 * (il&1);
|
qh = qh + 16 * (il&1);
|
||||||
uint8_t ul = 1 << (il/2);
|
uint8_t ul = 1 << (il/2);
|
||||||
il = il%4;
|
il = il & 3;
|
||||||
const uchar4 sc = get_scale_min_k4(is, xb->scales);
|
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
||||||
const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
|
const half d = il < 2 ? xb->d : xb->d / 16.h;
|
||||||
const float ml = il<2 ? min * sc[1] : min * sc[3];
|
const half min = xb->dmin;
|
||||||
|
const half dl = d * sc[0];
|
||||||
|
const half ml = min * sc[1];
|
||||||
|
|
||||||
const ushort mask = il<2 ? 0x0F : 0xF0;
|
const ushort mask = il<2 ? 0x0F : 0xF0;
|
||||||
const float qh_val = il<2 ? 16.f : 256.f;
|
const half qh_val = il<2 ? 16.h : 256.h;
|
||||||
for (int i = 0; i < 16; ++i) {
|
for (int i = 0; i < 16; ++i) {
|
||||||
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
|
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
|
||||||
}
|
}
|
||||||
@ -1931,7 +2124,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
|
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
|
||||||
const float d_all = (float)(xb->d);
|
const half d_all = xb->d;
|
||||||
device const uint8_t * ql = (device const uint8_t *)xb->ql;
|
device const uint8_t * ql = (device const uint8_t *)xb->ql;
|
||||||
device const uint8_t * qh = (device const uint8_t *)xb->qh;
|
device const uint8_t * qh = (device const uint8_t *)xb->qh;
|
||||||
device const int8_t * scales = (device const int8_t *)xb->scales;
|
device const int8_t * scales = (device const int8_t *)xb->scales;
|
||||||
@ -1939,19 +2132,21 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
|||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
|
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
|
||||||
qh = qh + 32*(il/8) + 16*(il&1);
|
qh = qh + 32*(il/8) + 16*(il&1);
|
||||||
float sc = scales[(il%2) + 2 * ((il/2))];
|
half sc = scales[(il%2) + 2 * ((il/2))];
|
||||||
il = (il/2)%4;
|
il = (il/2) & 3;
|
||||||
#else
|
#else
|
||||||
ql = ql + 16 * (il&1);
|
ql = ql + 16 * (il&1);
|
||||||
float sc = scales[il];
|
half sc = scales[il];
|
||||||
#endif
|
#endif
|
||||||
|
const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
||||||
|
const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
|
||||||
|
const half coef = il>1 ? 1.f/16.h : 1.h;
|
||||||
|
const half ml = d_all * sc * 32.h;
|
||||||
|
const half dl = d_all * sc * coef;
|
||||||
for (int i = 0; i < 16; ++i) {
|
for (int i = 0; i < 16; ++i) {
|
||||||
uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
|
||||||
uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
|
: ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
|
||||||
const float coef = il>1 ? 1.f/16.f : 1.f;
|
reg[i/4][i%4] = dl * q - ml;
|
||||||
float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
|
|
||||||
((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
|
|
||||||
reg[i/4][i%4] = d_all * sc * q * coef;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1991,13 +2186,16 @@ kernel void kernel_get_rows(
|
|||||||
// each block_q contains 16*nl weights
|
// each block_q contains 16*nl weights
|
||||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
||||||
kernel void kernel_mul_mm(device const uchar * src0,
|
kernel void kernel_mul_mm(device const uchar * src0,
|
||||||
device const float * src1,
|
device const uchar * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
constant int64_t & nb01,
|
constant int64_t & nb01,
|
||||||
constant int64_t & nb02,
|
constant int64_t & nb02,
|
||||||
constant int64_t & ne12,
|
constant int64_t & ne12,
|
||||||
|
constant int64_t & nb10,
|
||||||
|
constant int64_t & nb11,
|
||||||
|
constant int64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant uint & gqa,
|
constant uint & gqa,
|
||||||
@ -2006,7 +2204,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|||||||
uint tiitg[[thread_index_in_threadgroup]],
|
uint tiitg[[thread_index_in_threadgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
threadgroup half * sa = ((threadgroup half *)shared_memory);
|
threadgroup half * sa = (threadgroup half *)(shared_memory);
|
||||||
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
||||||
|
|
||||||
const uint r0 = tgpig.y;
|
const uint r0 = tgpig.y;
|
||||||
@ -2027,10 +2225,15 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|||||||
}
|
}
|
||||||
|
|
||||||
short il = (tiitg % THREAD_PER_ROW);
|
short il = (tiitg % THREAD_PER_ROW);
|
||||||
uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
|
|
||||||
|
uint offset0 = im/gqa*nb02;
|
||||||
|
ushort offset1 = il/nl;
|
||||||
|
|
||||||
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
|
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
|
||||||
device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
|
device const float * y = (device const float *)(src1
|
||||||
+ BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;
|
+ nb12 * im
|
||||||
|
+ nb11 * (r1 * BLOCK_SIZE_N + thread_col)
|
||||||
|
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
||||||
|
|
||||||
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
||||||
//load data and store to threadgroup memory
|
//load data and store to threadgroup memory
|
||||||
@ -2110,6 +2313,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|||||||
typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
|
typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
|
||||||
constant uint64_t &, constant uint64_t &, uint, uint, uint);
|
constant uint64_t &, constant uint64_t &, uint, uint, uint);
|
||||||
|
|
||||||
|
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
||||||
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
||||||
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
||||||
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
||||||
@ -2120,10 +2324,24 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
|
|||||||
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
||||||
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
||||||
|
|
||||||
typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
|
typedef void (mat_mm_t)(
|
||||||
constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
|
device const uchar * src0,
|
||||||
constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
|
device const uchar * src1,
|
||||||
|
device float * dst,
|
||||||
|
constant int64_t & ne00,
|
||||||
|
constant int64_t & ne02,
|
||||||
|
constant int64_t & nb01,
|
||||||
|
constant int64_t & nb02,
|
||||||
|
constant int64_t & ne12,
|
||||||
|
constant int64_t & nb10,
|
||||||
|
constant int64_t & nb11,
|
||||||
|
constant int64_t & nb12,
|
||||||
|
constant int64_t & ne0,
|
||||||
|
constant int64_t & ne1,
|
||||||
|
constant uint & gqa,
|
||||||
|
threadgroup uchar *, uint3, uint, uint);
|
||||||
|
|
||||||
|
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
|
||||||
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
||||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
||||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
||||||
|
137
ggml.c
137
ggml.c
@ -1,4 +1,3 @@
|
|||||||
#define _GNU_SOURCE // Defines CLOCK_MONOTONIC on Linux
|
|
||||||
#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
|
#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
|
||||||
|
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
@ -107,6 +106,9 @@ typedef void * thread_ret_t;
|
|||||||
#include <sys/stat.h>
|
#include <sys/stat.h>
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
|
|
||||||
|
#endif
|
||||||
|
#ifdef GGML_USE_CPU_HBM
|
||||||
|
#include <hbwmalloc.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
|
// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
|
||||||
@ -196,9 +198,15 @@ typedef void * thread_ret_t;
|
|||||||
#define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr)
|
#define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr)
|
||||||
#else
|
#else
|
||||||
inline static void * ggml_aligned_malloc(size_t size) {
|
inline static void * ggml_aligned_malloc(size_t size) {
|
||||||
|
if (size == 0) {
|
||||||
|
GGML_PRINT("WARNING: Behavior may be unexpected when allocating 0 bytes for ggml_aligned_malloc!\n");
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
void * aligned_memory = NULL;
|
void * aligned_memory = NULL;
|
||||||
#ifdef GGML_USE_METAL
|
#ifdef GGML_USE_CPU_HBM
|
||||||
int result = posix_memalign(&aligned_memory, getpagesize(), size);
|
int result = hbw_posix_memalign(&aligned_memory, 16, size);
|
||||||
|
#elif GGML_USE_METAL
|
||||||
|
int result = posix_memalign(&aligned_memory, sysconf(_SC_PAGESIZE), size);
|
||||||
#else
|
#else
|
||||||
int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
|
int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
|
||||||
#endif
|
#endif
|
||||||
@ -219,8 +227,12 @@ inline static void * ggml_aligned_malloc(size_t size) {
|
|||||||
return aligned_memory;
|
return aligned_memory;
|
||||||
}
|
}
|
||||||
#define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size)
|
#define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size)
|
||||||
|
#ifdef GGML_USE_CPU_HBM
|
||||||
|
#define GGML_ALIGNED_FREE(ptr) if(NULL != ptr) hbw_free(ptr)
|
||||||
|
#else
|
||||||
#define GGML_ALIGNED_FREE(ptr) free(ptr)
|
#define GGML_ALIGNED_FREE(ptr) free(ptr)
|
||||||
#endif
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
#define UNUSED GGML_UNUSED
|
#define UNUSED GGML_UNUSED
|
||||||
#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
|
#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
|
||||||
@ -4291,10 +4303,21 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
size_t ggml_nbytes(const struct ggml_tensor * tensor) {
|
size_t ggml_nbytes(const struct ggml_tensor * tensor) {
|
||||||
size_t nbytes = tensor->ne[0]*tensor->nb[0]/ggml_blck_size(tensor->type);
|
size_t nbytes;
|
||||||
|
size_t blck_size = ggml_blck_size(tensor->type);
|
||||||
|
if (blck_size == 1) {
|
||||||
|
nbytes = ggml_type_size(tensor->type);
|
||||||
|
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
|
||||||
|
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
|
||||||
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
|
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
|
||||||
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
|
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nbytes;
|
return nbytes;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -4572,6 +4595,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
|||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// allow to call ggml_init with 0 size
|
||||||
|
if (params.mem_size == 0) {
|
||||||
|
params.mem_size = GGML_MEM_ALIGN;
|
||||||
|
}
|
||||||
|
|
||||||
const size_t mem_size = params.mem_buffer ? params.mem_size : GGML_PAD(params.mem_size, GGML_MEM_ALIGN);
|
const size_t mem_size = params.mem_buffer ? params.mem_size : GGML_PAD(params.mem_size, GGML_MEM_ALIGN);
|
||||||
|
|
||||||
*ctx = (struct ggml_context) {
|
*ctx = (struct ggml_context) {
|
||||||
@ -4774,7 +4802,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
|
|||||||
|
|
||||||
size_t obj_alloc_size = 0;
|
size_t obj_alloc_size = 0;
|
||||||
|
|
||||||
if (view_src == NULL && ctx->no_alloc == false) {
|
if (view_src == NULL && !ctx->no_alloc) {
|
||||||
if (ctx->scratch.data != NULL) {
|
if (ctx->scratch.data != NULL) {
|
||||||
// allocate tensor data in the scratch buffer
|
// allocate tensor data in the scratch buffer
|
||||||
if (ctx->scratch.offs + data_size > ctx->scratch.size) {
|
if (ctx->scratch.offs + data_size > ctx->scratch.size) {
|
||||||
@ -5475,7 +5503,7 @@ static struct ggml_tensor * ggml_mul_impl(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (inplace) {
|
if (inplace) {
|
||||||
GGML_ASSERT(is_node == false);
|
GGML_ASSERT(!is_node);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||||
@ -5518,7 +5546,7 @@ static struct ggml_tensor * ggml_div_impl(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (inplace) {
|
if (inplace) {
|
||||||
GGML_ASSERT(is_node == false);
|
GGML_ASSERT(!is_node);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||||
@ -17266,10 +17294,18 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
|||||||
} else {
|
} else {
|
||||||
// wait for other threads to finish
|
// wait for other threads to finish
|
||||||
const int last = node_n;
|
const int last = node_n;
|
||||||
do {
|
while (true) {
|
||||||
//sched_yield();
|
// TODO: this sched_yield can have significant impact on the performance - either positive or negative
|
||||||
|
// depending on the workload and the operating system.
|
||||||
|
// since it is not clear what is the best approach, it should potentially become user-configurable
|
||||||
|
// ref: https://github.com/ggerganov/ggml/issues/291
|
||||||
|
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
||||||
|
sched_yield();
|
||||||
|
#endif
|
||||||
|
|
||||||
node_n = atomic_load(&state->shared->node_n);
|
node_n = atomic_load(&state->shared->node_n);
|
||||||
} while (node_n == last);
|
if (node_n != last) break;
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if we should stop
|
// check if we should stop
|
||||||
@ -18320,10 +18356,11 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
|
|||||||
for (int i = 0; i < cgraph->n_leafs; i++) {
|
for (int i = 0; i < cgraph->n_leafs; i++) {
|
||||||
struct ggml_tensor * node = cgraph->leafs[i];
|
struct ggml_tensor * node = cgraph->leafs[i];
|
||||||
|
|
||||||
GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
|
GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n",
|
||||||
i,
|
i,
|
||||||
node->ne[0], node->ne[1],
|
node->ne[0], node->ne[1],
|
||||||
ggml_op_name(node->op));
|
ggml_op_name(node->op),
|
||||||
|
ggml_get_name(node));
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < GGML_OP_COUNT; i++) {
|
for (int i = 0; i < GGML_OP_COUNT; i++) {
|
||||||
@ -19962,7 +19999,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
|
|||||||
|
|
||||||
struct ggml_tensor * data = NULL;
|
struct ggml_tensor * data = NULL;
|
||||||
|
|
||||||
if (params.no_alloc == false) {
|
if (!params.no_alloc) {
|
||||||
data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size);
|
data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size);
|
||||||
|
|
||||||
ok = ok && data != NULL;
|
ok = ok && data != NULL;
|
||||||
@ -20003,7 +20040,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
|
|||||||
}
|
}
|
||||||
|
|
||||||
// point the data member to the appropriate location in the binary blob using the tensor infos
|
// point the data member to the appropriate location in the binary blob using the tensor infos
|
||||||
if (params.no_alloc == false) {
|
if (!params.no_alloc) {
|
||||||
//cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file
|
//cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file
|
||||||
cur->data = (char *) data->data + ctx->infos[i].offset; // offset from data
|
cur->data = (char *) data->data + ctx->infos[i].offset; // offset from data
|
||||||
}
|
}
|
||||||
@ -20082,27 +20119,27 @@ const char * gguf_type_name(enum gguf_type type) {
|
|||||||
return GGUF_TYPE_NAME[type];
|
return GGUF_TYPE_NAME[type];
|
||||||
}
|
}
|
||||||
|
|
||||||
int gguf_get_version(struct gguf_context * ctx) {
|
int gguf_get_version(const struct gguf_context * ctx) {
|
||||||
return ctx->header.version;
|
return ctx->header.version;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t gguf_get_alignment(struct gguf_context * ctx) {
|
size_t gguf_get_alignment(const struct gguf_context * ctx) {
|
||||||
return ctx->alignment;
|
return ctx->alignment;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t gguf_get_data_offset(struct gguf_context * ctx) {
|
size_t gguf_get_data_offset(const struct gguf_context * ctx) {
|
||||||
return ctx->offset;
|
return ctx->offset;
|
||||||
}
|
}
|
||||||
|
|
||||||
void * gguf_get_data(struct gguf_context * ctx) {
|
void * gguf_get_data(const struct gguf_context * ctx) {
|
||||||
return ctx->data;
|
return ctx->data;
|
||||||
}
|
}
|
||||||
|
|
||||||
int gguf_get_n_kv(struct gguf_context * ctx) {
|
int gguf_get_n_kv(const struct gguf_context * ctx) {
|
||||||
return ctx->header.n_kv;
|
return ctx->header.n_kv;
|
||||||
}
|
}
|
||||||
|
|
||||||
int gguf_find_key(struct gguf_context * ctx, const char * key) {
|
int gguf_find_key(const struct gguf_context * ctx, const char * key) {
|
||||||
// return -1 if key not found
|
// return -1 if key not found
|
||||||
int keyfound = -1;
|
int keyfound = -1;
|
||||||
|
|
||||||
@ -20118,85 +20155,85 @@ int gguf_find_key(struct gguf_context * ctx, const char * key) {
|
|||||||
return keyfound;
|
return keyfound;
|
||||||
}
|
}
|
||||||
|
|
||||||
const char * gguf_get_key(struct gguf_context * ctx, int i) {
|
const char * gguf_get_key(const struct gguf_context * ctx, int i) {
|
||||||
return ctx->kv[i].key.data;
|
return ctx->kv[i].key.data;
|
||||||
}
|
}
|
||||||
|
|
||||||
enum gguf_type gguf_get_kv_type(struct gguf_context * ctx, int i) {
|
enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int i) {
|
||||||
return ctx->kv[i].type;
|
return ctx->kv[i].type;
|
||||||
}
|
}
|
||||||
|
|
||||||
enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i) {
|
enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int i) {
|
||||||
return ctx->kv[i].value.arr.type;
|
return ctx->kv[i].value.arr.type;
|
||||||
}
|
}
|
||||||
|
|
||||||
const void * gguf_get_arr_data(struct gguf_context * ctx, int i) {
|
const void * gguf_get_arr_data(const struct gguf_context * ctx, int i) {
|
||||||
return ctx->kv[i].value.arr.data;
|
return ctx->kv[i].value.arr.data;
|
||||||
}
|
}
|
||||||
|
|
||||||
const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i) {
|
const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) {
|
||||||
struct gguf_kv * kv = &ctx->kv[key_id];
|
struct gguf_kv * kv = &ctx->kv[key_id];
|
||||||
struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
|
struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
|
||||||
return str->data;
|
return str->data;
|
||||||
}
|
}
|
||||||
|
|
||||||
int gguf_get_arr_n(struct gguf_context * ctx, int i) {
|
int gguf_get_arr_n(const struct gguf_context * ctx, int i) {
|
||||||
return ctx->kv[i].value.arr.n;
|
return ctx->kv[i].value.arr.n;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint8_t gguf_get_val_u8(struct gguf_context * ctx, int i) {
|
uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int i) {
|
||||||
return ctx->kv[i].value.uint8;
|
return ctx->kv[i].value.uint8;
|
||||||
}
|
}
|
||||||
|
|
||||||
int8_t gguf_get_val_i8(struct gguf_context * ctx, int i) {
|
int8_t gguf_get_val_i8(const struct gguf_context * ctx, int i) {
|
||||||
return ctx->kv[i].value.int8;
|
return ctx->kv[i].value.int8;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint16_t gguf_get_val_u16(struct gguf_context * ctx, int i) {
|
uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int i) {
|
||||||
return ctx->kv[i].value.uint16;
|
return ctx->kv[i].value.uint16;
|
||||||
}
|
}
|
||||||
|
|
||||||
int16_t gguf_get_val_i16(struct gguf_context * ctx, int i) {
|
int16_t gguf_get_val_i16(const struct gguf_context * ctx, int i) {
|
||||||
return ctx->kv[i].value.int16;
|
return ctx->kv[i].value.int16;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t gguf_get_val_u32(struct gguf_context * ctx, int i) {
|
uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int i) {
|
||||||
return ctx->kv[i].value.uint32;
|
return ctx->kv[i].value.uint32;
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t gguf_get_val_i32(struct gguf_context * ctx, int i) {
|
int32_t gguf_get_val_i32(const struct gguf_context * ctx, int i) {
|
||||||
return ctx->kv[i].value.int32;
|
return ctx->kv[i].value.int32;
|
||||||
}
|
}
|
||||||
|
|
||||||
float gguf_get_val_f32(struct gguf_context * ctx, int i) {
|
float gguf_get_val_f32(const struct gguf_context * ctx, int i) {
|
||||||
return ctx->kv[i].value.float32;
|
return ctx->kv[i].value.float32;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t gguf_get_val_u64(struct gguf_context * ctx, int i) {
|
uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int i) {
|
||||||
return ctx->kv[i].value.uint64;
|
return ctx->kv[i].value.uint64;
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t gguf_get_val_i64(struct gguf_context * ctx, int i) {
|
int64_t gguf_get_val_i64(const struct gguf_context * ctx, int i) {
|
||||||
return ctx->kv[i].value.int64;
|
return ctx->kv[i].value.int64;
|
||||||
}
|
}
|
||||||
|
|
||||||
double gguf_get_val_f64(struct gguf_context * ctx, int i) {
|
double gguf_get_val_f64(const struct gguf_context * ctx, int i) {
|
||||||
return ctx->kv[i].value.float64;
|
return ctx->kv[i].value.float64;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool gguf_get_val_bool(struct gguf_context * ctx, int i) {
|
bool gguf_get_val_bool(const struct gguf_context * ctx, int i) {
|
||||||
return ctx->kv[i].value.bool_;
|
return ctx->kv[i].value.bool_;
|
||||||
}
|
}
|
||||||
|
|
||||||
const char * gguf_get_val_str (struct gguf_context * ctx, int i) {
|
const char * gguf_get_val_str (const struct gguf_context * ctx, int i) {
|
||||||
return ctx->kv[i].value.str.data;
|
return ctx->kv[i].value.str.data;
|
||||||
}
|
}
|
||||||
|
|
||||||
int gguf_get_n_tensors(struct gguf_context * ctx) {
|
int gguf_get_n_tensors(const struct gguf_context * ctx) {
|
||||||
return ctx->header.n_tensors;
|
return ctx->header.n_tensors;
|
||||||
}
|
}
|
||||||
|
|
||||||
int gguf_find_tensor(struct gguf_context * ctx, const char * name) {
|
int gguf_find_tensor(const struct gguf_context * ctx, const char * name) {
|
||||||
// return -1 if tensor not found
|
// return -1 if tensor not found
|
||||||
int tensorfound = -1;
|
int tensorfound = -1;
|
||||||
|
|
||||||
@ -20212,11 +20249,11 @@ int gguf_find_tensor(struct gguf_context * ctx, const char * name) {
|
|||||||
return tensorfound;
|
return tensorfound;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t gguf_get_tensor_offset(struct gguf_context * ctx, int i) {
|
size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i) {
|
||||||
return ctx->infos[i].offset;
|
return ctx->infos[i].offset;
|
||||||
}
|
}
|
||||||
|
|
||||||
char * gguf_get_tensor_name(struct gguf_context * ctx, int i) {
|
char * gguf_get_tensor_name(const struct gguf_context * ctx, int i) {
|
||||||
return ctx->infos[i].name.data;
|
return ctx->infos[i].name.data;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -20499,7 +20536,7 @@ static void gguf_bwrite_el(struct gguf_buf * buf, const void * val, size_t el_si
|
|||||||
buf->offset += el_size;
|
buf->offset += el_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void gguf_write_to_buf(struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) {
|
static void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) {
|
||||||
// write header
|
// write header
|
||||||
gguf_bwrite_el(buf, &ctx->header.magic, sizeof(ctx->header.magic));
|
gguf_bwrite_el(buf, &ctx->header.magic, sizeof(ctx->header.magic));
|
||||||
gguf_bwrite_el(buf, &ctx->header.version, sizeof(ctx->header.version));
|
gguf_bwrite_el(buf, &ctx->header.version, sizeof(ctx->header.version));
|
||||||
@ -20614,7 +20651,7 @@ static void gguf_write_to_buf(struct gguf_context * ctx, struct gguf_buf * buf,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void gguf_write_to_file(struct gguf_context * ctx, const char * fname, bool only_meta) {
|
void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) {
|
||||||
FILE * file = fopen(fname, "wb");
|
FILE * file = fopen(fname, "wb");
|
||||||
if (!file) {
|
if (!file) {
|
||||||
GGML_ASSERT(false && "failed to open file for writing");
|
GGML_ASSERT(false && "failed to open file for writing");
|
||||||
@ -20631,7 +20668,7 @@ void gguf_write_to_file(struct gguf_context * ctx, const char * fname, bool only
|
|||||||
fclose(file);
|
fclose(file);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t gguf_get_meta_size(struct gguf_context * ctx) {
|
size_t gguf_get_meta_size(const struct gguf_context * ctx) {
|
||||||
// no allocs - only compute size
|
// no allocs - only compute size
|
||||||
struct gguf_buf buf = gguf_buf_init(0);
|
struct gguf_buf buf = gguf_buf_init(0);
|
||||||
|
|
||||||
@ -20640,7 +20677,7 @@ size_t gguf_get_meta_size(struct gguf_context * ctx) {
|
|||||||
return buf.offset;
|
return buf.offset;
|
||||||
}
|
}
|
||||||
|
|
||||||
void gguf_get_meta_data(struct gguf_context * ctx, void * data) {
|
void gguf_get_meta_data(const struct gguf_context * ctx, void * data) {
|
||||||
struct gguf_buf buf = gguf_buf_init(16*1024);
|
struct gguf_buf buf = gguf_buf_init(16*1024);
|
||||||
|
|
||||||
gguf_write_to_buf(ctx, &buf, true);
|
gguf_write_to_buf(ctx, &buf, true);
|
||||||
@ -20716,6 +20753,14 @@ int ggml_cpu_has_arm_fma(void) {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int ggml_cpu_has_metal(void) {
|
||||||
|
#if defined(GGML_USE_METAL)
|
||||||
|
return 1;
|
||||||
|
#else
|
||||||
|
return 0;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
int ggml_cpu_has_f16c(void) {
|
int ggml_cpu_has_f16c(void) {
|
||||||
#if defined(__F16C__)
|
#if defined(__F16C__)
|
||||||
return 1;
|
return 1;
|
||||||
|
72
ggml.h
72
ggml.h
@ -195,6 +195,14 @@
|
|||||||
# define GGML_DEPRECATED(func, hint) func
|
# define GGML_DEPRECATED(func, hint) func
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifndef __GNUC__
|
||||||
|
# define GGML_ATTRIBUTE_FORMAT(...)
|
||||||
|
#elif defined(__MINGW32__)
|
||||||
|
# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
|
||||||
|
#else
|
||||||
|
# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
|
||||||
|
#endif
|
||||||
|
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdbool.h>
|
#include <stdbool.h>
|
||||||
@ -685,6 +693,7 @@ extern "C" {
|
|||||||
|
|
||||||
GGML_API const char * ggml_get_name (const struct ggml_tensor * tensor);
|
GGML_API const char * ggml_get_name (const struct ggml_tensor * tensor);
|
||||||
GGML_API struct ggml_tensor * ggml_set_name ( struct ggml_tensor * tensor, const char * name);
|
GGML_API struct ggml_tensor * ggml_set_name ( struct ggml_tensor * tensor, const char * name);
|
||||||
|
GGML_ATTRIBUTE_FORMAT(2, 3)
|
||||||
GGML_API struct ggml_tensor * ggml_format_name( struct ggml_tensor * tensor, const char * fmt, ...);
|
GGML_API struct ggml_tensor * ggml_format_name( struct ggml_tensor * tensor, const char * fmt, ...);
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -1866,39 +1875,39 @@ extern "C" {
|
|||||||
|
|
||||||
GGML_API const char * gguf_type_name(enum gguf_type type);
|
GGML_API const char * gguf_type_name(enum gguf_type type);
|
||||||
|
|
||||||
GGML_API int gguf_get_version (struct gguf_context * ctx);
|
GGML_API int gguf_get_version (const struct gguf_context * ctx);
|
||||||
GGML_API size_t gguf_get_alignment (struct gguf_context * ctx);
|
GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx);
|
||||||
GGML_API size_t gguf_get_data_offset(struct gguf_context * ctx);
|
GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx);
|
||||||
GGML_API void * gguf_get_data (struct gguf_context * ctx);
|
GGML_API void * gguf_get_data (const struct gguf_context * ctx);
|
||||||
|
|
||||||
GGML_API int gguf_get_n_kv(struct gguf_context * ctx);
|
GGML_API int gguf_get_n_kv(const struct gguf_context * ctx);
|
||||||
GGML_API int gguf_find_key(struct gguf_context * ctx, const char * key);
|
GGML_API int gguf_find_key(const struct gguf_context * ctx, const char * key);
|
||||||
GGML_API const char * gguf_get_key (struct gguf_context * ctx, int i);
|
GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int i);
|
||||||
|
|
||||||
GGML_API enum gguf_type gguf_get_kv_type (struct gguf_context * ctx, int i);
|
GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int i);
|
||||||
GGML_API enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i);
|
GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int i);
|
||||||
|
|
||||||
// results are undefined if the wrong type is used for the key
|
// results are undefined if the wrong type is used for the key
|
||||||
GGML_API uint8_t gguf_get_val_u8 (struct gguf_context * ctx, int i);
|
GGML_API uint8_t gguf_get_val_u8 (const struct gguf_context * ctx, int i);
|
||||||
GGML_API int8_t gguf_get_val_i8 (struct gguf_context * ctx, int i);
|
GGML_API int8_t gguf_get_val_i8 (const struct gguf_context * ctx, int i);
|
||||||
GGML_API uint16_t gguf_get_val_u16 (struct gguf_context * ctx, int i);
|
GGML_API uint16_t gguf_get_val_u16 (const struct gguf_context * ctx, int i);
|
||||||
GGML_API int16_t gguf_get_val_i16 (struct gguf_context * ctx, int i);
|
GGML_API int16_t gguf_get_val_i16 (const struct gguf_context * ctx, int i);
|
||||||
GGML_API uint32_t gguf_get_val_u32 (struct gguf_context * ctx, int i);
|
GGML_API uint32_t gguf_get_val_u32 (const struct gguf_context * ctx, int i);
|
||||||
GGML_API int32_t gguf_get_val_i32 (struct gguf_context * ctx, int i);
|
GGML_API int32_t gguf_get_val_i32 (const struct gguf_context * ctx, int i);
|
||||||
GGML_API float gguf_get_val_f32 (struct gguf_context * ctx, int i);
|
GGML_API float gguf_get_val_f32 (const struct gguf_context * ctx, int i);
|
||||||
GGML_API uint64_t gguf_get_val_u64 (struct gguf_context * ctx, int i);
|
GGML_API uint64_t gguf_get_val_u64 (const struct gguf_context * ctx, int i);
|
||||||
GGML_API int64_t gguf_get_val_i64 (struct gguf_context * ctx, int i);
|
GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int i);
|
||||||
GGML_API double gguf_get_val_f64 (struct gguf_context * ctx, int i);
|
GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int i);
|
||||||
GGML_API bool gguf_get_val_bool(struct gguf_context * ctx, int i);
|
GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int i);
|
||||||
GGML_API const char * gguf_get_val_str (struct gguf_context * ctx, int i);
|
GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int i);
|
||||||
GGML_API int gguf_get_arr_n (struct gguf_context * ctx, int i);
|
GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int i);
|
||||||
GGML_API const void * gguf_get_arr_data(struct gguf_context * ctx, int i);
|
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int i);
|
||||||
GGML_API const char * gguf_get_arr_str (struct gguf_context * ctx, int key_id, int i);
|
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
|
||||||
|
|
||||||
GGML_API int gguf_get_n_tensors (struct gguf_context * ctx);
|
GGML_API int gguf_get_n_tensors (const struct gguf_context * ctx);
|
||||||
GGML_API int gguf_find_tensor (struct gguf_context * ctx, const char * name);
|
GGML_API int gguf_find_tensor (const struct gguf_context * ctx, const char * name);
|
||||||
GGML_API size_t gguf_get_tensor_offset(struct gguf_context * ctx, int i);
|
GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i);
|
||||||
GGML_API char * gguf_get_tensor_name (struct gguf_context * ctx, int i);
|
GGML_API char * gguf_get_tensor_name (const struct gguf_context * ctx, int i);
|
||||||
|
|
||||||
// overrides existing values or adds a new one
|
// overrides existing values or adds a new one
|
||||||
GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val);
|
GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val);
|
||||||
@ -1943,11 +1952,11 @@ extern "C" {
|
|||||||
//
|
//
|
||||||
|
|
||||||
// write the entire context to a binary file
|
// write the entire context to a binary file
|
||||||
GGML_API void gguf_write_to_file(struct gguf_context * ctx, const char * fname, bool only_meta);
|
GGML_API void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta);
|
||||||
|
|
||||||
// get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
|
// get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
|
||||||
GGML_API size_t gguf_get_meta_size(struct gguf_context * ctx);
|
GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
|
||||||
GGML_API void gguf_get_meta_data(struct gguf_context * ctx, void * data);
|
GGML_API void gguf_get_meta_data(const struct gguf_context * ctx, void * data);
|
||||||
|
|
||||||
//
|
//
|
||||||
// system info
|
// system info
|
||||||
@ -1961,6 +1970,7 @@ extern "C" {
|
|||||||
GGML_API int ggml_cpu_has_fma (void);
|
GGML_API int ggml_cpu_has_fma (void);
|
||||||
GGML_API int ggml_cpu_has_neon (void);
|
GGML_API int ggml_cpu_has_neon (void);
|
||||||
GGML_API int ggml_cpu_has_arm_fma (void);
|
GGML_API int ggml_cpu_has_arm_fma (void);
|
||||||
|
GGML_API int ggml_cpu_has_metal (void);
|
||||||
GGML_API int ggml_cpu_has_f16c (void);
|
GGML_API int ggml_cpu_has_f16c (void);
|
||||||
GGML_API int ggml_cpu_has_fp16_va (void);
|
GGML_API int ggml_cpu_has_fp16_va (void);
|
||||||
GGML_API int ggml_cpu_has_wasm_simd (void);
|
GGML_API int ggml_cpu_has_wasm_simd (void);
|
||||||
|
@ -1,57 +0,0 @@
|
|||||||
# - "turn on lights."
|
|
||||||
# - "set thermostat to 22."
|
|
||||||
# - "increase TV by 10."
|
|
||||||
# - "decrease oven by 50."
|
|
||||||
# - "play music."
|
|
||||||
# - "stop podcast."
|
|
||||||
# - "schedule cleaning at 3pm."
|
|
||||||
# - "cancel cleaning."
|
|
||||||
# - "remind me to buy milk at 5pm."
|
|
||||||
# - "show me security system."
|
|
||||||
# - "hide washing machine."
|
|
||||||
# - "what is the lights status?"
|
|
||||||
# - "what is the current thermostat value?"
|
|
||||||
# - "what is the security system status?"
|
|
||||||
# - "what is the door lock status?"
|
|
||||||
# - "what is the camera battery level?"
|
|
||||||
# - "what is the weather like today?"
|
|
||||||
# - "what is the forecast for tomorrow?"
|
|
||||||
# - "what is the time?"
|
|
||||||
# - "what is my schedule for today?"
|
|
||||||
# - "what tasks do I have?"
|
|
||||||
# - "what reminders do I have?"
|
|
||||||
#
|
|
||||||
# example:
|
|
||||||
#
|
|
||||||
# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/assistant.gbnf --prompt "Ok Whisper, start listening for commands." --context "Whisper is a home assistant. It recognizes voice commands. Time is 11pm." --grammar-penalty 10
|
|
||||||
#
|
|
||||||
|
|
||||||
root ::= init " " (command | question) "."
|
|
||||||
prompt ::= init
|
|
||||||
|
|
||||||
# leading space is very important!
|
|
||||||
init ::= " Ok Whisper, start listening for commands."
|
|
||||||
|
|
||||||
command ::= "Turn " ("on" | "off") " " device | "Set " device " to " value |
|
|
||||||
"Increase " device " by " value | "Decrease " device " by " value |
|
|
||||||
"Play " media | "Stop " media | "Schedule " task " at " time | "Cancel " task |
|
|
||||||
"Remind me to " task " at " time | "Show me " device | "Hide " device
|
|
||||||
|
|
||||||
question ::= "What is the " device " status?" | "What is the current " device " value?" |
|
|
||||||
"What is the " device " temperature?" | "What is the " device " humidity?" |
|
|
||||||
"What is the " device " power consumption?" | "What is the " device " battery level?" |
|
|
||||||
"What is the weather like today?" | "What is the forecast for tomorrow?" |
|
|
||||||
"What is the time?" | "What is my schedule for today?" | "What tasks do I have?" |
|
|
||||||
"What reminders do I have?"
|
|
||||||
|
|
||||||
device ::= "lights" | "thermostat" | "security system" | "door lock" | "camera" | "speaker" | "TV" |
|
|
||||||
"music player" | "coffee machine" | "oven" | "refrigerator" | "washing machine" |
|
|
||||||
"vacuum cleaner"
|
|
||||||
|
|
||||||
value ::= [0-9]+
|
|
||||||
|
|
||||||
media ::= "music" | "radio" | "podcast" | "audiobook" | "TV show" | "movie"
|
|
||||||
|
|
||||||
task ::= [a-zA-Z]+ (" " [a-zA-Z]+)?
|
|
||||||
|
|
||||||
time ::= [0-9] [0-9]? ("am" | "pm")?
|
|
@ -1,29 +0,0 @@
|
|||||||
# - bishop to c3
|
|
||||||
# - rook to d4
|
|
||||||
# - knight to e5
|
|
||||||
# - d4 d5 knight to c3
|
|
||||||
# - c3 queen to d4 king b1
|
|
||||||
# - pawn to a1 bishop to b2 knight to c3
|
|
||||||
#
|
|
||||||
# The prompt (--prompt) is the initial phrase that the user has to say.
|
|
||||||
# This is used to prime Whisper with how the user is expected to speak.
|
|
||||||
#
|
|
||||||
# Provide long context (--context) with sample moves to help Whisper decode the correct sequence.
|
|
||||||
# Longer context is better, but it slightly increases the processing time.
|
|
||||||
#
|
|
||||||
# example:
|
|
||||||
#
|
|
||||||
# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/chess.gbnf --prompt "rook to b4, f3," --context "d4 d5 knight to c3, pawn to a1, bishop to b2 king e8," --grammar-penalty 100
|
|
||||||
#
|
|
||||||
|
|
||||||
root ::= init move move? move? "."
|
|
||||||
prompt ::= init "."
|
|
||||||
|
|
||||||
# leading space is very important!
|
|
||||||
init ::= " rook to b4, f3"
|
|
||||||
|
|
||||||
move ::= ", " ((piece | pawn | king) " " "to "?)? [a-h] [1-8]
|
|
||||||
|
|
||||||
piece ::= "bishop" | "rook" | "knight" | "queen"
|
|
||||||
king ::= "king"
|
|
||||||
pawn ::= "pawn"
|
|
@ -1,16 +0,0 @@
|
|||||||
# - red
|
|
||||||
# - green
|
|
||||||
# - blue
|
|
||||||
#
|
|
||||||
# example:
|
|
||||||
#
|
|
||||||
# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/colors.gbnf --prompt "red, green, blue," --context "green, red, blue,"
|
|
||||||
#
|
|
||||||
|
|
||||||
root ::= init color "."
|
|
||||||
prompt ::= init "."
|
|
||||||
|
|
||||||
# leading space is very important!
|
|
||||||
init ::= " red, green, blue"
|
|
||||||
|
|
||||||
color ::= ", " ("red" | "green" | "blue")
|
|
117
models/convert-h5-to-coreml.py
Normal file
117
models/convert-h5-to-coreml.py
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
import argparse
|
||||||
|
import importlib.util
|
||||||
|
|
||||||
|
spec = importlib.util.spec_from_file_location('whisper_to_coreml', 'models/convert-whisper-to-coreml.py')
|
||||||
|
whisper_to_coreml = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(whisper_to_coreml)
|
||||||
|
|
||||||
|
from whisper import load_model
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
import torch
|
||||||
|
from transformers import WhisperForConditionalGeneration
|
||||||
|
from huggingface_hub import metadata_update
|
||||||
|
|
||||||
|
# https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets/blob/main/src/multiple_datasets/hub_default_utils.py
|
||||||
|
WHISPER_MAPPING = {
|
||||||
|
"layers": "blocks",
|
||||||
|
"fc1": "mlp.0",
|
||||||
|
"fc2": "mlp.2",
|
||||||
|
"final_layer_norm": "mlp_ln",
|
||||||
|
"layers": "blocks",
|
||||||
|
".self_attn.q_proj": ".attn.query",
|
||||||
|
".self_attn.k_proj": ".attn.key",
|
||||||
|
".self_attn.v_proj": ".attn.value",
|
||||||
|
".self_attn_layer_norm": ".attn_ln",
|
||||||
|
".self_attn.out_proj": ".attn.out",
|
||||||
|
".encoder_attn.q_proj": ".cross_attn.query",
|
||||||
|
".encoder_attn.k_proj": ".cross_attn.key",
|
||||||
|
".encoder_attn.v_proj": ".cross_attn.value",
|
||||||
|
".encoder_attn_layer_norm": ".cross_attn_ln",
|
||||||
|
".encoder_attn.out_proj": ".cross_attn.out",
|
||||||
|
"decoder.layer_norm.": "decoder.ln.",
|
||||||
|
"encoder.layer_norm.": "encoder.ln_post.",
|
||||||
|
"embed_tokens": "token_embedding",
|
||||||
|
"encoder.embed_positions.weight": "encoder.positional_embedding",
|
||||||
|
"decoder.embed_positions.weight": "decoder.positional_embedding",
|
||||||
|
"layer_norm": "ln_post",
|
||||||
|
}
|
||||||
|
|
||||||
|
# https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets/blob/main/src/multiple_datasets/hub_default_utils.py
|
||||||
|
def rename_keys(s_dict):
|
||||||
|
keys = list(s_dict.keys())
|
||||||
|
for key in keys:
|
||||||
|
new_key = key
|
||||||
|
for k, v in WHISPER_MAPPING.items():
|
||||||
|
if k in key:
|
||||||
|
new_key = new_key.replace(k, v)
|
||||||
|
|
||||||
|
print(f"{key} -> {new_key}")
|
||||||
|
|
||||||
|
s_dict[new_key] = s_dict.pop(key)
|
||||||
|
return s_dict
|
||||||
|
|
||||||
|
# https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets/blob/main/src/multiple_datasets/hub_default_utils.py
|
||||||
|
def convert_hf_whisper(hf_model_name_or_path: str, whisper_state_path: str):
|
||||||
|
transformer_model = WhisperForConditionalGeneration.from_pretrained(hf_model_name_or_path)
|
||||||
|
config = transformer_model.config
|
||||||
|
|
||||||
|
# first build dims
|
||||||
|
dims = {
|
||||||
|
'n_mels': config.num_mel_bins,
|
||||||
|
'n_vocab': config.vocab_size,
|
||||||
|
'n_audio_ctx': config.max_source_positions,
|
||||||
|
'n_audio_state': config.d_model,
|
||||||
|
'n_audio_head': config.encoder_attention_heads,
|
||||||
|
'n_audio_layer': config.encoder_layers,
|
||||||
|
'n_text_ctx': config.max_target_positions,
|
||||||
|
'n_text_state': config.d_model,
|
||||||
|
'n_text_head': config.decoder_attention_heads,
|
||||||
|
'n_text_layer': config.decoder_layers
|
||||||
|
}
|
||||||
|
|
||||||
|
state_dict = deepcopy(transformer_model.model.state_dict())
|
||||||
|
state_dict = rename_keys(state_dict)
|
||||||
|
|
||||||
|
torch.save({"dims": dims, "model_state_dict": state_dict}, whisper_state_path)
|
||||||
|
|
||||||
|
# Ported from models/convert-whisper-to-coreml.py
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--model-name", type=str, help="name of model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large, large-v1)", required=True)
|
||||||
|
parser.add_argument("--model-path", type=str, help="path to the model (e.g. if published on HuggingFace: Oblivion208/whisper-tiny-cantonese)", required=True)
|
||||||
|
parser.add_argument("--encoder-only", type=bool, help="only convert encoder", default=False)
|
||||||
|
parser.add_argument("--quantize", type=bool, help="quantize weights to F16", default=False)
|
||||||
|
parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.model_name not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1"]:
|
||||||
|
raise ValueError("Invalid model name")
|
||||||
|
|
||||||
|
pt_target_path = f"models/hf-{args.model_name}.pt"
|
||||||
|
convert_hf_whisper(args.model_path, pt_target_path)
|
||||||
|
|
||||||
|
whisper = load_model(pt_target_path).cpu()
|
||||||
|
hparams = whisper.dims
|
||||||
|
print(hparams)
|
||||||
|
|
||||||
|
if args.optimize_ane:
|
||||||
|
whisperANE = whisper_to_coreml.WhisperANE(hparams).eval()
|
||||||
|
whisperANE.load_state_dict(whisper.state_dict())
|
||||||
|
|
||||||
|
encoder = whisperANE.encoder
|
||||||
|
decoder = whisperANE.decoder
|
||||||
|
else:
|
||||||
|
encoder = whisper.encoder
|
||||||
|
decoder = whisper.decoder
|
||||||
|
|
||||||
|
# Convert encoder
|
||||||
|
encoder = whisper_to_coreml.convert_encoder(hparams, encoder, quantize=args.quantize)
|
||||||
|
encoder.save(f"models/coreml-encoder-{args.model_name}.mlpackage")
|
||||||
|
|
||||||
|
if args.encoder_only is False:
|
||||||
|
# Convert decoder
|
||||||
|
decoder = whisper_to_coreml.convert_decoder(hparams, decoder, quantize=args.quantize)
|
||||||
|
decoder.save(f"models/coreml-decoder-{args.model_name}.mlpackage")
|
||||||
|
|
||||||
|
print("done converting")
|
@ -29,7 +29,7 @@ def convert_encoder(hparams, encoder, mname):
|
|||||||
|
|
||||||
# use model optimizer to convert onnx to OpenVINO IR format
|
# use model optimizer to convert onnx to OpenVINO IR format
|
||||||
encoder_model = mo.convert_model(onnx_path, compress_to_fp16=True)
|
encoder_model = mo.convert_model(onnx_path, compress_to_fp16=True)
|
||||||
serialize(encoder_model, xml_path='ggml-' + mname + '-encoder-openvino.xml')
|
serialize(encoder_model, xml_path=os.path.join(os.path.dirname(__file__),"ggml-" + mname + "-encoder-openvino.xml"))
|
||||||
|
|
||||||
#cleanup
|
#cleanup
|
||||||
if os.path.isdir(onnx_folder):
|
if os.path.isdir(onnx_folder):
|
||||||
|
@ -40,7 +40,7 @@ if exist "ggml-%model%.bin" (
|
|||||||
goto :eof
|
goto :eof
|
||||||
)
|
)
|
||||||
|
|
||||||
PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-%model%.bin -OutFile ggml-%model%.bin"
|
PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Start-BitsTransfer -Source https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-%model%.bin -Destination ggml-%model%.bin"
|
||||||
|
|
||||||
if %ERRORLEVEL% neq 0 (
|
if %ERRORLEVEL% neq 0 (
|
||||||
echo Failed to download ggml model %model%
|
echo Failed to download ggml model %model%
|
||||||
|
@ -22,7 +22,28 @@ function get_script_path() {
|
|||||||
models_path="$(get_script_path)"
|
models_path="$(get_script_path)"
|
||||||
|
|
||||||
# Whisper models
|
# Whisper models
|
||||||
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small.en-tdrz" "small" "medium.en" "medium" "large-v1" "large" )
|
models=(
|
||||||
|
"tiny.en"
|
||||||
|
"tiny"
|
||||||
|
"tiny-q5_1"
|
||||||
|
"tiny.en-q5_1"
|
||||||
|
"base.en"
|
||||||
|
"base"
|
||||||
|
"base-q5_1"
|
||||||
|
"base.en-q5_1"
|
||||||
|
"small.en"
|
||||||
|
"small.en-tdrz"
|
||||||
|
"small"
|
||||||
|
"small-q5_1"
|
||||||
|
"small.en-q5_1"
|
||||||
|
"medium"
|
||||||
|
"medium.en"
|
||||||
|
"medium-q5_0"
|
||||||
|
"medium.en-q5_0"
|
||||||
|
"large-v1"
|
||||||
|
"large"
|
||||||
|
"large-q5_0"
|
||||||
|
)
|
||||||
|
|
||||||
# list available models
|
# list available models
|
||||||
function list_models {
|
function list_models {
|
||||||
|
@ -1,10 +1,14 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
# Usage: ./generate-coreml-model.sh <model-name>
|
# Usage: ./generate-coreml-model.sh <model-name>
|
||||||
if [ $# -eq 0 ]
|
if [ $# -eq 0 ]; then
|
||||||
then
|
|
||||||
echo "No model name supplied"
|
echo "No model name supplied"
|
||||||
echo "Usage: ./generate-coreml-model.sh <model-name>"
|
echo "Usage for Whisper models: ./generate-coreml-model.sh <model-name>"
|
||||||
|
echo "Usage for HuggingFace models: ./generate-coreml-model.sh -h5 <model-name> <model-path>"
|
||||||
|
exit 1
|
||||||
|
elif [[ "$1" == "-h5" && $# != 3 ]]; then
|
||||||
|
echo "No model name and model path supplied for a HuggingFace model"
|
||||||
|
echo "Usage for HuggingFace models: ./generate-coreml-model.sh -h5 <model-name> <model-path>"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
@ -13,7 +17,14 @@ mname="$1"
|
|||||||
wd=$(dirname "$0")
|
wd=$(dirname "$0")
|
||||||
cd "$wd/../"
|
cd "$wd/../"
|
||||||
|
|
||||||
|
if [[ $mname == "-h5" ]]; then
|
||||||
|
mname="$2"
|
||||||
|
mpath="$3"
|
||||||
|
echo $mpath
|
||||||
|
python3 models/convert-h5-to-coreml.py --model-name $mname --model-path $mpath --encoder-only True
|
||||||
|
else
|
||||||
python3 models/convert-whisper-to-coreml.py --model $mname --encoder-only True
|
python3 models/convert-whisper-to-coreml.py --model $mname --encoder-only True
|
||||||
|
fi
|
||||||
|
|
||||||
xcrun coremlc compile models/coreml-encoder-${mname}.mlpackage models/
|
xcrun coremlc compile models/coreml-encoder-${mname}.mlpackage models/
|
||||||
rm -rf models/ggml-${mname}-encoder.mlmodelc
|
rm -rf models/ggml-${mname}-encoder.mlmodelc
|
||||||
|
2004
whisper.cpp
2004
whisper.cpp
File diff suppressed because it is too large
Load Diff
47
whisper.h
47
whisper.h
@ -96,37 +96,6 @@ extern "C" {
|
|||||||
void (*close)(void * ctx);
|
void (*close)(void * ctx);
|
||||||
} whisper_model_loader;
|
} whisper_model_loader;
|
||||||
|
|
||||||
// grammar element type
|
|
||||||
enum whisper_gretype {
|
|
||||||
// end of rule definition
|
|
||||||
WHISPER_GRETYPE_END = 0,
|
|
||||||
|
|
||||||
// start of alternate definition for rule
|
|
||||||
WHISPER_GRETYPE_ALT = 1,
|
|
||||||
|
|
||||||
// non-terminal element: reference to rule
|
|
||||||
WHISPER_GRETYPE_RULE_REF = 2,
|
|
||||||
|
|
||||||
// terminal element: character (code point)
|
|
||||||
WHISPER_GRETYPE_CHAR = 3,
|
|
||||||
|
|
||||||
// inverse char(s) ([^a], [^a-b] [^abc])
|
|
||||||
WHISPER_GRETYPE_CHAR_NOT = 4,
|
|
||||||
|
|
||||||
// modifies a preceding WHISPER_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
|
|
||||||
// be an inclusive range ([a-z])
|
|
||||||
WHISPER_GRETYPE_CHAR_RNG_UPPER = 5,
|
|
||||||
|
|
||||||
// modifies a preceding WHISPER_GRETYPE_CHAR or
|
|
||||||
// WHISPER_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
|
|
||||||
WHISPER_GRETYPE_CHAR_ALT = 6,
|
|
||||||
};
|
|
||||||
|
|
||||||
typedef struct whisper_grammar_element {
|
|
||||||
enum whisper_gretype type;
|
|
||||||
uint32_t value; // Unicode code point or rule ID
|
|
||||||
} whisper_grammar_element;
|
|
||||||
|
|
||||||
// Various functions for loading a ggml whisper model.
|
// Various functions for loading a ggml whisper model.
|
||||||
// Allocate (almost) all memory needed for the model.
|
// Allocate (almost) all memory needed for the model.
|
||||||
// Return NULL on failure
|
// Return NULL on failure
|
||||||
@ -365,6 +334,11 @@ extern "C" {
|
|||||||
// If it returns false, the computation is aborted
|
// If it returns false, the computation is aborted
|
||||||
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);
|
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);
|
||||||
|
|
||||||
|
// Abort callback
|
||||||
|
// If not NULL, called before ggml computation
|
||||||
|
// If it returns true, the computation is aborted
|
||||||
|
typedef bool (*whisper_abort_callback)(void * user_data);
|
||||||
|
|
||||||
// Logits filter callback
|
// Logits filter callback
|
||||||
// Can be used to modify the logits before sampling
|
// Can be used to modify the logits before sampling
|
||||||
// If not NULL, called after applying temperature to logits
|
// If not NULL, called after applying temperature to logits
|
||||||
@ -389,7 +363,6 @@ extern "C" {
|
|||||||
|
|
||||||
bool translate;
|
bool translate;
|
||||||
bool no_context; // do not use past transcription (if any) as initial prompt for the decoder
|
bool no_context; // do not use past transcription (if any) as initial prompt for the decoder
|
||||||
bool no_timestamps; // do not generate timestamps
|
|
||||||
bool single_segment; // force single segment output (useful for streaming)
|
bool single_segment; // force single segment output (useful for streaming)
|
||||||
bool print_special; // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
|
bool print_special; // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
|
||||||
bool print_progress; // print progress information
|
bool print_progress; // print progress information
|
||||||
@ -460,14 +433,13 @@ extern "C" {
|
|||||||
whisper_encoder_begin_callback encoder_begin_callback;
|
whisper_encoder_begin_callback encoder_begin_callback;
|
||||||
void * encoder_begin_callback_user_data;
|
void * encoder_begin_callback_user_data;
|
||||||
|
|
||||||
|
// called each time before ggml computation starts
|
||||||
|
whisper_abort_callback abort_callback;
|
||||||
|
void * abort_callback_user_data;
|
||||||
|
|
||||||
// called by each decoder to filter obtained logits
|
// called by each decoder to filter obtained logits
|
||||||
whisper_logits_filter_callback logits_filter_callback;
|
whisper_logits_filter_callback logits_filter_callback;
|
||||||
void * logits_filter_callback_user_data;
|
void * logits_filter_callback_user_data;
|
||||||
|
|
||||||
const whisper_grammar_element ** grammar_rules;
|
|
||||||
size_t n_grammar_rules;
|
|
||||||
size_t i_start_rule;
|
|
||||||
float grammar_penalty;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_params()
|
// NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_params()
|
||||||
@ -522,6 +494,7 @@ extern "C" {
|
|||||||
|
|
||||||
// Get whether the next segment is predicted as a speaker turn
|
// Get whether the next segment is predicted as a speaker turn
|
||||||
WHISPER_API bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment);
|
WHISPER_API bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment);
|
||||||
|
WHISPER_API bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment);
|
||||||
|
|
||||||
// Get the text of the specified segment
|
// Get the text of the specified segment
|
||||||
WHISPER_API const char * whisper_full_get_segment_text (struct whisper_context * ctx, int i_segment);
|
WHISPER_API const char * whisper_full_get_segment_text (struct whisper_context * ctx, int i_segment);
|
||||||
|
Reference in New Issue
Block a user