mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-11-07 08:34:37 +01:00
Feature/java bindings2 (#944)
* Java needs to call `whisper_full_default_params_by_ref()`, returning struct by val does not seem to work. * added convenience methods to WhisperFullParams * Remove unused WhisperJavaParams
This commit is contained in:
parent
9b926844e3
commit
d7c936b44a
47
.github/workflows/build.yml
vendored
47
.github/workflows/build.yml
vendored
@ -125,8 +125,10 @@ jobs:
|
||||
include:
|
||||
- arch: Win32
|
||||
s2arc: x86
|
||||
jnaPath: win32-x86
|
||||
- arch: x64
|
||||
s2arc: x64
|
||||
jnaPath: win32-x86-64
|
||||
- sdl2: ON
|
||||
s2ver: 2.26.0
|
||||
|
||||
@ -159,6 +161,12 @@ jobs:
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
|
||||
|
||||
- name: Upload dll
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: ${{ matrix.jnaPath }}_whisper.dll
|
||||
path: build/bin/${{ matrix.build }}/whisper.dll
|
||||
|
||||
- name: Upload binaries
|
||||
if: matrix.sdl2 == 'ON'
|
||||
uses: actions/upload-artifact@v1
|
||||
@ -363,3 +371,42 @@ jobs:
|
||||
run: |
|
||||
cd examples/whisper.android
|
||||
./gradlew assembleRelease --no-daemon
|
||||
|
||||
java:
|
||||
needs: [ 'windows' ]
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v1
|
||||
|
||||
- name: Install Java
|
||||
uses: actions/setup-java@v1
|
||||
with:
|
||||
java-version: 17
|
||||
|
||||
- name: Download Windows lib
|
||||
uses: actions/download-artifact@v3
|
||||
with:
|
||||
name: win32-x86-64_whisper.dll
|
||||
path: bindings/java/build/generated/resources/main/win32-x86-64
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
models\download-ggml-model.cmd tiny.en
|
||||
cd bindings/java
|
||||
chmod +x ./gradlew
|
||||
./gradlew build
|
||||
|
||||
- name: Upload jar
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: whispercpp.jar
|
||||
path: bindings/java/build/libs/whispercpp-*.jar
|
||||
|
||||
# - name: Publish package
|
||||
# if: ${{ github.ref == 'refs/heads/master' }}
|
||||
# uses: gradle/gradle-build-action@v2
|
||||
# with:
|
||||
# arguments: publish
|
||||
# env:
|
||||
# MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }}
|
||||
# MAVEN_PASSWORD: ${{ secrets.OSSRH_TOKEN }}
|
||||
|
@ -1,50 +0,0 @@
|
||||
cmake_minimum_required(VERSION 3.10)
|
||||
|
||||
project(whisper_java VERSION 1.4.2)
|
||||
|
||||
# Set the target name and source file/s
|
||||
set(TARGET_NAME whisper_java)
|
||||
set(SOURCES src/main/cpp/whisper_java.cpp)
|
||||
|
||||
# include <whisper.h>
|
||||
include_directories(../../)
|
||||
|
||||
# Set the output directory for the DLL/shared library based on the platform as required by JNA
|
||||
if(WIN32)
|
||||
set(OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/win32-x86-64)
|
||||
elseif(UNIX AND NOT APPLE)
|
||||
set(OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/linux-x86-64)
|
||||
elseif(APPLE)
|
||||
set(OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/macos-x86-64)
|
||||
endif()
|
||||
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${OUTPUT_DIR})
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${OUTPUT_DIR})
|
||||
|
||||
# Create the whisper_java library
|
||||
add_library(${TARGET_NAME} SHARED ${SOURCES})
|
||||
|
||||
# Link against ../../build/Release/whisper.dll (or so/dynlib)
|
||||
target_link_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/../../../build/${CMAKE_BUILD_TYPE})
|
||||
target_link_libraries(${TARGET_NAME} PRIVATE whisper)
|
||||
|
||||
# Set the appropriate compiler flags for Windows, Linux, and macOS
|
||||
if(WIN32)
|
||||
target_compile_options(${TARGET_NAME} PRIVATE /W4 /D_CRT_SECURE_NO_WARNINGS)
|
||||
elseif(UNIX AND NOT APPLE)
|
||||
target_compile_options(${TARGET_NAME} PRIVATE -Wall -Wextra)
|
||||
elseif(APPLE)
|
||||
target_compile_options(${TARGET_NAME} PRIVATE -Wall -Wextra)
|
||||
endif()
|
||||
|
||||
target_compile_definitions(${TARGET_NAME} PRIVATE WHISPER_SHARED)
|
||||
# add_definitions(-DWHISPER_SHARED)
|
||||
|
||||
# Force CMake to save the libs to build/generated/resources/main/${os}-${arch} as required by JNA
|
||||
foreach(OUTPUTCONFIG ${CMAKE_CONFIGURATION_TYPES})
|
||||
string(TOUPPER ${OUTPUTCONFIG} OUTPUTCONFIG)
|
||||
set_target_properties(${TARGET_NAME} PROPERTIES
|
||||
RUNTIME_OUTPUT_DIRECTORY_${OUTPUTCONFIG} ${OUTPUT_DIR}
|
||||
LIBRARY_OUTPUT_DIRECTORY_${OUTPUTCONFIG} ${OUTPUT_DIR}
|
||||
ARCHIVE_OUTPUT_DIRECTORY_${OUTPUTCONFIG} ${OUTPUT_DIR})
|
||||
endforeach(OUTPUTCONFIG CMAKE_CONFIGURATION_TYPES)
|
@ -6,11 +6,7 @@ This package provides Java JNI bindings for whisper.cpp. They have been tested o
|
||||
* Ubuntu on x86_64
|
||||
* Windows on x86_64
|
||||
|
||||
The "low level" bindings are in `WhisperCppJnaLibrary` and `WhisperJavaJnaLibrary` which caches `whisper_full_params` and `whisper_context` in `whisper_java.cpp`.
|
||||
|
||||
There are a lot of classes in the `callbacks`, `ggml`, `model` and `params` directories but most of them have not been tested.
|
||||
|
||||
The most simple usage is as follows:
|
||||
The "low level" bindings are in `WhisperCppJnaLibrary`. The most simple usage is as follows:
|
||||
|
||||
```java
|
||||
import io.github.ggerganov.whispercpp.WhisperCpp;
|
||||
@ -48,12 +44,6 @@ In order to build, you need to have the JDK 8 or higher installed. Run the tests
|
||||
git clone https://github.com/ggerganov/whisper.cpp.git
|
||||
cd whisper.cpp/bindings/java
|
||||
|
||||
mkdir build
|
||||
pushd build
|
||||
cmake ..
|
||||
cmake --build .
|
||||
popd
|
||||
|
||||
./gradlew build
|
||||
```
|
||||
|
||||
|
@ -22,6 +22,12 @@ sourceSets {
|
||||
}
|
||||
}
|
||||
|
||||
tasks.register('copyLibwhisperDynlib', Copy) {
|
||||
from '../../build'
|
||||
include 'libwhisper.dynlib'
|
||||
into 'build/generated/resources/main/darwin'
|
||||
}
|
||||
|
||||
tasks.register('copyLibwhisperSo', Copy) {
|
||||
from '../../build'
|
||||
include 'libwhisper.so'
|
||||
@ -34,7 +40,9 @@ tasks.register('copyWhisperDll', Copy) {
|
||||
into 'build/generated/resources/main/windows-x86-64'
|
||||
}
|
||||
|
||||
tasks.build.dependsOn copyLibwhisperSo, copyWhisperDll
|
||||
tasks.register('copyLibs') {
|
||||
dependsOn copyLibwhisperDynlib, copyLibwhisperSo, copyWhisperDll
|
||||
}
|
||||
|
||||
test {
|
||||
systemProperty 'jna.library.path', project.file('build/generated/resources/main').absolutePath
|
||||
|
@ -1,33 +0,0 @@
|
||||
#include <stdio.h>
|
||||
#include "whisper_java.h"
|
||||
|
||||
struct whisper_full_params default_params;
|
||||
struct whisper_context * whisper_ctx = nullptr;
|
||||
|
||||
struct void whisper_java_default_params(enum whisper_sampling_strategy strategy) {
|
||||
default_params = whisper_full_default_params(strategy);
|
||||
|
||||
// struct whisper_java_params result = {};
|
||||
// return result;
|
||||
return;
|
||||
}
|
||||
|
||||
void whisper_java_init_from_file(const char * path_model) {
|
||||
whisper_ctx = whisper_init_from_file(path_model);
|
||||
if (0 == default_params.n_threads) {
|
||||
whisper_java_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
}
|
||||
}
|
||||
|
||||
/** Delegates to whisper_full, but without having to pass `whisper_full_params` */
|
||||
int whisper_java_full(
|
||||
struct whisper_context * ctx,
|
||||
// struct whisper_java_params params,
|
||||
const float * samples,
|
||||
int n_samples) {
|
||||
return whisper_full(ctx, default_params, samples, n_samples);
|
||||
}
|
||||
|
||||
void whisper_java_free() {
|
||||
// free(default_params);
|
||||
}
|
@ -1,24 +0,0 @@
|
||||
#define WHISPER_BUILD
|
||||
#include <whisper.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
struct whisper_java_params {
|
||||
};
|
||||
|
||||
WHISPER_API void whisper_java_default_params(enum whisper_sampling_strategy strategy);
|
||||
|
||||
WHISPER_API void whisper_java_init_from_file(const char * path_model);
|
||||
|
||||
WHISPER_API int whisper_java_full(
|
||||
struct whisper_context * ctx,
|
||||
// struct whisper_java_params params,
|
||||
const float * samples,
|
||||
int n_samples);
|
||||
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
@ -1,7 +1,8 @@
|
||||
package io.github.ggerganov.whispercpp;
|
||||
|
||||
import com.sun.jna.Native;
|
||||
import com.sun.jna.Pointer;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperJavaParams;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
|
||||
|
||||
import java.io.File;
|
||||
@ -13,8 +14,9 @@ import java.io.IOException;
|
||||
*/
|
||||
public class WhisperCpp implements AutoCloseable {
|
||||
private WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance;
|
||||
private WhisperJavaJnaLibrary javaLib = WhisperJavaJnaLibrary.instance;
|
||||
private Pointer ctx = null;
|
||||
private Pointer greedyPointer = null;
|
||||
private Pointer beamPointer = null;
|
||||
|
||||
public File modelDir() {
|
||||
String modelDirPath = System.getenv("XDG_CACHE_HOME");
|
||||
@ -27,9 +29,8 @@ public class WhisperCpp implements AutoCloseable {
|
||||
|
||||
/**
|
||||
* @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")
|
||||
* @return a Pointer to the WhisperContext
|
||||
*/
|
||||
void initContext(String modelPath) throws FileNotFoundException {
|
||||
public void initContext(String modelPath) throws FileNotFoundException {
|
||||
if (ctx != null) {
|
||||
lib.whisper_free(ctx);
|
||||
}
|
||||
@ -42,7 +43,6 @@ public class WhisperCpp implements AutoCloseable {
|
||||
modelPath = new File(modelDir(), modelPath).getAbsolutePath();
|
||||
}
|
||||
|
||||
javaLib.whisper_java_init_from_file(modelPath);
|
||||
ctx = lib.whisper_init_from_file(modelPath);
|
||||
|
||||
if (ctx == null) {
|
||||
@ -51,22 +51,38 @@ public class WhisperCpp implements AutoCloseable {
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialises `whisper_full_params` internally in whisper_java.cpp so JNA doesn't have to map everything.
|
||||
* `whisper_java_init_from_file()` calls `whisper_java_default_params(WHISPER_SAMPLING_GREEDY)` for convenience.
|
||||
* Provides default params which can be used with `whisper_full()` etc.
|
||||
* Because this function allocates memory for the params, the caller must call either:
|
||||
* - call `whisper_free_params()`
|
||||
* - `Native.free(Pointer.nativeValue(pointer));`
|
||||
*
|
||||
* @param strategy - GREEDY
|
||||
*/
|
||||
public void getDefaultJavaParams(WhisperSamplingStrategy strategy) {
|
||||
javaLib.whisper_java_default_params(strategy.ordinal());
|
||||
// return lib.whisper_full_default_params(strategy.value)
|
||||
}
|
||||
public WhisperFullParams getFullDefaultParams(WhisperSamplingStrategy strategy) {
|
||||
Pointer pointer;
|
||||
|
||||
// whisper_full_default_params was too hard to integrate with, so for now we use javaLib.whisper_java_default_params
|
||||
// fun getDefaultParams(strategy: WhisperSamplingStrategy): WhisperFullParams {
|
||||
// return lib.whisper_full_default_params(strategy.value)
|
||||
// }
|
||||
// whisper_full_default_params_by_ref allocates memory which we need to delete, so only create max 1 pointer for each strategy.
|
||||
if (strategy == WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY) {
|
||||
if (greedyPointer == null) {
|
||||
greedyPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
|
||||
}
|
||||
pointer = greedyPointer;
|
||||
} else {
|
||||
if (beamPointer == null) {
|
||||
beamPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
|
||||
}
|
||||
pointer = beamPointer;
|
||||
}
|
||||
|
||||
WhisperFullParams params = new WhisperFullParams(pointer);
|
||||
params.read();
|
||||
return params;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
freeContext();
|
||||
freeParams();
|
||||
System.out.println("Whisper closed");
|
||||
}
|
||||
|
||||
@ -76,17 +92,28 @@ public class WhisperCpp implements AutoCloseable {
|
||||
}
|
||||
}
|
||||
|
||||
private void freeParams() {
|
||||
if (greedyPointer != null) {
|
||||
Native.free(Pointer.nativeValue(greedyPointer));
|
||||
greedyPointer = null;
|
||||
}
|
||||
if (beamPointer != null) {
|
||||
Native.free(Pointer.nativeValue(beamPointer));
|
||||
beamPointer = null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.
|
||||
* Not thread safe for same context
|
||||
* Uses the specified decoding strategy to obtain the text.
|
||||
*/
|
||||
public String fullTranscribe(/*WhisperJavaParams whisperParams,*/ float[] audioData) throws IOException {
|
||||
public String fullTranscribe(WhisperFullParams whisperParams, float[] audioData) throws IOException {
|
||||
if (ctx == null) {
|
||||
throw new IllegalStateException("Model not initialised");
|
||||
}
|
||||
|
||||
if (javaLib.whisper_java_full(ctx, /*whisperParams,*/ audioData, audioData.length) != 0) {
|
||||
if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) {
|
||||
throw new IOException("Failed to process audio");
|
||||
}
|
||||
|
||||
|
@ -231,10 +231,21 @@ public interface WhisperCppJnaLibrary extends Library {
|
||||
void whisper_print_timings(Pointer ctx);
|
||||
void whisper_reset_timings(Pointer ctx);
|
||||
|
||||
// Note: Even if `whisper_full_params is stripped back to just 4 ints, JNA throws "Invalid memory access"
|
||||
// when `whisper_full_default_params()` tries to return a struct.
|
||||
// WhisperFullParams whisper_full_default_params(int strategy);
|
||||
|
||||
/**
|
||||
* Provides default params which can be used with `whisper_full()` etc.
|
||||
* Because this function allocates memory for the params, the caller must call either:
|
||||
* - call `whisper_free_params()`
|
||||
* - `Native.free(Pointer.nativeValue(pointer));`
|
||||
*
|
||||
* @param strategy - WhisperSamplingStrategy.value
|
||||
*/
|
||||
WhisperFullParams whisper_full_default_params(int strategy);
|
||||
Pointer whisper_full_default_params_by_ref(int strategy);
|
||||
|
||||
void whisper_free_params(Pointer params);
|
||||
|
||||
/**
|
||||
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
||||
|
@ -1,23 +0,0 @@
|
||||
package io.github.ggerganov.whispercpp;
|
||||
|
||||
import com.sun.jna.Library;
|
||||
import com.sun.jna.Native;
|
||||
import com.sun.jna.Pointer;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperJavaParams;
|
||||
|
||||
interface WhisperJavaJnaLibrary extends Library {
|
||||
WhisperJavaJnaLibrary instance = Native.load("whisper_java", WhisperJavaJnaLibrary.class);
|
||||
|
||||
void whisper_java_default_params(int strategy);
|
||||
|
||||
void whisper_java_free();
|
||||
|
||||
void whisper_java_init_from_file(String modelPath);
|
||||
|
||||
/**
|
||||
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.
|
||||
* Not thread safe for same context
|
||||
* Uses the specified decoding strategy to obtain the text.
|
||||
*/
|
||||
int whisper_java_full(Pointer ctx, /*WhisperJavaParams params, */float[] samples, int nSamples);
|
||||
}
|
@ -20,5 +20,5 @@ public interface WhisperEncoderBeginCallback extends Callback {
|
||||
* @param user_data User data.
|
||||
* @return True if the computation should proceed, false otherwise.
|
||||
*/
|
||||
boolean callback(WhisperContext ctx, WhisperState state, Pointer user_data);
|
||||
boolean callback(Pointer ctx, Pointer state, Pointer user_data);
|
||||
}
|
||||
|
@ -1,12 +1,9 @@
|
||||
package io.github.ggerganov.whispercpp.callbacks;
|
||||
|
||||
import com.sun.jna.Callback;
|
||||
import com.sun.jna.Pointer;
|
||||
import io.github.ggerganov.whispercpp.WhisperContext;
|
||||
import io.github.ggerganov.whispercpp.model.WhisperState;
|
||||
import io.github.ggerganov.whispercpp.model.WhisperTokenData;
|
||||
|
||||
import javax.security.auth.callback.Callback;
|
||||
|
||||
/**
|
||||
* Callback to filter logits.
|
||||
* Can be used to modify the logits before sampling.
|
||||
@ -24,5 +21,5 @@ public interface WhisperLogitsFilterCallback extends Callback {
|
||||
* @param logits The array of logits.
|
||||
* @param user_data User data.
|
||||
*/
|
||||
void callback(WhisperContext ctx, WhisperState state, WhisperTokenData[] tokens, int n_tokens, float[] logits, Pointer user_data);
|
||||
void callback(Pointer ctx, Pointer state, WhisperTokenData[] tokens, int n_tokens, float[] logits, Pointer user_data);
|
||||
}
|
||||
|
@ -20,5 +20,5 @@ public interface WhisperNewSegmentCallback extends Callback {
|
||||
* @param n_new The number of newly generated text segments.
|
||||
* @param user_data User data.
|
||||
*/
|
||||
void callback(WhisperContext ctx, WhisperState state, int n_new, Pointer user_data);
|
||||
void callback(Pointer ctx, Pointer state, int n_new, Pointer user_data);
|
||||
}
|
||||
|
@ -1,11 +1,10 @@
|
||||
package io.github.ggerganov.whispercpp.callbacks;
|
||||
|
||||
import com.sun.jna.Callback;
|
||||
import com.sun.jna.Pointer;
|
||||
import io.github.ggerganov.whispercpp.WhisperContext;
|
||||
import io.github.ggerganov.whispercpp.model.WhisperState;
|
||||
|
||||
import javax.security.auth.callback.Callback;
|
||||
|
||||
/**
|
||||
* Callback for progress updates.
|
||||
*/
|
||||
@ -19,5 +18,5 @@ public interface WhisperProgressCallback extends Callback {
|
||||
* @param progress The progress value.
|
||||
* @param user_data User data.
|
||||
*/
|
||||
void callback(WhisperContext ctx, WhisperState state, int progress, Pointer user_data);
|
||||
void callback(Pointer ctx, Pointer state, int progress, Pointer user_data);
|
||||
}
|
||||
|
@ -0,0 +1,19 @@
|
||||
package io.github.ggerganov.whispercpp.params;
|
||||
|
||||
import com.sun.jna.Structure;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
public class BeamSearchParams extends Structure {
|
||||
/** ref: <a href="https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265">...</a> */
|
||||
public int beam_size;
|
||||
|
||||
/** ref: <a href="https://arxiv.org/pdf/2204.05424.pdf">...</a> */
|
||||
public float patience;
|
||||
|
||||
@Override
|
||||
protected List<String> getFieldOrder() {
|
||||
return Arrays.asList("beam_size", "patience");
|
||||
}
|
||||
}
|
@ -0,0 +1,30 @@
|
||||
package io.github.ggerganov.whispercpp.params;
|
||||
|
||||
import com.sun.jna.IntegerType;
|
||||
|
||||
import java.util.function.BooleanSupplier;
|
||||
|
||||
public class CBool extends IntegerType implements BooleanSupplier {
|
||||
public static final int SIZE = 1;
|
||||
public static final CBool FALSE = new CBool(0);
|
||||
public static final CBool TRUE = new CBool(1);
|
||||
|
||||
|
||||
public CBool() {
|
||||
this(0);
|
||||
}
|
||||
|
||||
public CBool(long value) {
|
||||
super(SIZE, value, true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean getAsBoolean() {
|
||||
return intValue() == 1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return intValue() == 1 ? "true" : "false";
|
||||
}
|
||||
}
|
@ -0,0 +1,16 @@
|
||||
package io.github.ggerganov.whispercpp.params;
|
||||
|
||||
import com.sun.jna.Structure;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
public class GreedyParams extends Structure {
|
||||
/** <a href="https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264">...</a> */
|
||||
public int best_of;
|
||||
|
||||
@Override
|
||||
protected List<String> getFieldOrder() {
|
||||
return Collections.singletonList("best_of");
|
||||
}
|
||||
}
|
@ -1,13 +1,14 @@
|
||||
package io.github.ggerganov.whispercpp.params;
|
||||
|
||||
import com.sun.jna.Callback;
|
||||
import com.sun.jna.Pointer;
|
||||
import com.sun.jna.Structure;
|
||||
import com.sun.jna.*;
|
||||
import io.github.ggerganov.whispercpp.callbacks.WhisperEncoderBeginCallback;
|
||||
import io.github.ggerganov.whispercpp.callbacks.WhisperLogitsFilterCallback;
|
||||
import io.github.ggerganov.whispercpp.callbacks.WhisperNewSegmentCallback;
|
||||
import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Parameters for the whisper_full() function.
|
||||
* If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
|
||||
@ -15,62 +16,123 @@ import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback;
|
||||
*/
|
||||
public class WhisperFullParams extends Structure {
|
||||
|
||||
public WhisperFullParams(Pointer p) {
|
||||
super(p);
|
||||
// super(p, ALIGN_MSVC);
|
||||
// super(p, ALIGN_GNUC);
|
||||
}
|
||||
|
||||
/** Sampling strategy for whisper_full() function. */
|
||||
public int strategy;
|
||||
|
||||
/** Number of threads. */
|
||||
/** Number of threads. (default = 4) */
|
||||
public int n_threads;
|
||||
|
||||
/** Maximum tokens to use from past text as a prompt for the decoder. */
|
||||
/** Maximum tokens to use from past text as a prompt for the decoder. (default = 16384) */
|
||||
public int n_max_text_ctx;
|
||||
|
||||
/** Start offset in milliseconds. */
|
||||
/** Start offset in milliseconds. (default = 0) */
|
||||
public int offset_ms;
|
||||
|
||||
/** Audio duration to process in milliseconds. */
|
||||
/** Audio duration to process in milliseconds. (default = 0) */
|
||||
public int duration_ms;
|
||||
|
||||
/** Translate flag. */
|
||||
public boolean translate;
|
||||
/** Translate flag. (default = false) */
|
||||
public CBool translate;
|
||||
|
||||
/** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. */
|
||||
public boolean no_context;
|
||||
/** The compliment of translateMode() */
|
||||
public void transcribeMode() {
|
||||
translate = CBool.FALSE;
|
||||
}
|
||||
|
||||
/** Flag to force single segment output (useful for streaming). */
|
||||
public boolean single_segment;
|
||||
/** The compliment of transcribeMode() */
|
||||
public void translateMode() {
|
||||
translate = CBool.TRUE;
|
||||
}
|
||||
|
||||
/** Flag to print special tokens (e.g., <SOT>, <EOT>, <BEG>, etc.). */
|
||||
public boolean print_special;
|
||||
/** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. (default = true) */
|
||||
public CBool no_context;
|
||||
|
||||
/** Flag to print progress information. */
|
||||
public boolean print_progress;
|
||||
/** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. (default = true) */
|
||||
public void enableContext(boolean enable) {
|
||||
no_context = enable ? CBool.FALSE : CBool.TRUE;
|
||||
}
|
||||
|
||||
/** Flag to print results from within whisper.cpp (avoid it, use callback instead). */
|
||||
public boolean print_realtime;
|
||||
/** Flag to force single segment output (useful for streaming). (default = false) */
|
||||
public CBool single_segment;
|
||||
|
||||
/** Flag to print timestamps for each text segment when printing realtime. */
|
||||
public boolean print_timestamps;
|
||||
/** Flag to force single segment output (useful for streaming). (default = false) */
|
||||
public void singleSegment(boolean single) {
|
||||
single_segment = single ? CBool.TRUE : CBool.FALSE;
|
||||
}
|
||||
|
||||
/** [EXPERIMENTAL] Flag to enable token-level timestamps. */
|
||||
public boolean token_timestamps;
|
||||
/** Flag to print special tokens (e.g., <SOT>, <EOT>, <BEG>, etc.). (default = false) */
|
||||
public CBool print_special;
|
||||
|
||||
/** [EXPERIMENTAL] Timestamp token probability threshold (~0.01). */
|
||||
/** Flag to print special tokens (e.g., <SOT>, <EOT>, <BEG>, etc.). (default = false) */
|
||||
public void printSpecial(boolean enable) {
|
||||
print_special = enable ? CBool.TRUE : CBool.FALSE;
|
||||
}
|
||||
|
||||
/** Flag to print progress information. (default = true) */
|
||||
public CBool print_progress;
|
||||
|
||||
/** Flag to print progress information. (default = true) */
|
||||
public void printProgress(boolean enable) {
|
||||
print_progress = enable ? CBool.TRUE : CBool.FALSE;
|
||||
}
|
||||
|
||||
/** Flag to print results from within whisper.cpp (avoid it, use callback instead). (default = true) */
|
||||
public CBool print_realtime;
|
||||
|
||||
/** Flag to print results from within whisper.cpp (avoid it, use callback instead). (default = true) */
|
||||
public void printRealtime(boolean enable) {
|
||||
print_realtime = enable ? CBool.TRUE : CBool.FALSE;
|
||||
}
|
||||
|
||||
/** Flag to print timestamps for each text segment when printing realtime. (default = true) */
|
||||
public CBool print_timestamps;
|
||||
|
||||
/** Flag to print timestamps for each text segment when printing realtime. (default = true) */
|
||||
public void printTimestamps(boolean enable) {
|
||||
print_timestamps = enable ? CBool.TRUE : CBool.FALSE;
|
||||
}
|
||||
|
||||
/** [EXPERIMENTAL] Flag to enable token-level timestamps. (default = false) */
|
||||
public CBool token_timestamps;
|
||||
|
||||
/** [EXPERIMENTAL] Flag to enable token-level timestamps. (default = false) */
|
||||
public void tokenTimestamps(boolean enable) {
|
||||
token_timestamps = enable ? CBool.TRUE : CBool.FALSE;
|
||||
}
|
||||
|
||||
/** [EXPERIMENTAL] Timestamp token probability threshold (~0.01). (default = 0.01) */
|
||||
public float thold_pt;
|
||||
|
||||
/** [EXPERIMENTAL] Timestamp token sum probability threshold (~0.01). */
|
||||
public float thold_ptsum;
|
||||
|
||||
/** Maximum segment length in characters. */
|
||||
/** Maximum segment length in characters. (default = 0) */
|
||||
public int max_len;
|
||||
|
||||
/** Flag to split on word rather than on token (when used with max_len). */
|
||||
public boolean split_on_word;
|
||||
/** Flag to split on word rather than on token (when used with max_len). (default = false) */
|
||||
public CBool split_on_word;
|
||||
|
||||
/** Maximum tokens per segment (0 = no limit). */
|
||||
/** Flag to split on word rather than on token (when used with max_len). (default = false) */
|
||||
public void splitOnWord(boolean enable) {
|
||||
split_on_word = enable ? CBool.TRUE : CBool.FALSE;
|
||||
}
|
||||
|
||||
/** Maximum tokens per segment (0, default = no limit) */
|
||||
public int max_tokens;
|
||||
|
||||
/** Flag to speed up the audio by 2x using Phase Vocoder. */
|
||||
public boolean speed_up;
|
||||
/** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */
|
||||
public CBool speed_up;
|
||||
|
||||
/** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */
|
||||
public void speedUp(boolean enable) {
|
||||
speed_up = enable ? CBool.TRUE : CBool.FALSE;
|
||||
}
|
||||
|
||||
/** Overwrite the audio context size (0 = use default). */
|
||||
public int audio_ctx;
|
||||
@ -79,9 +141,15 @@ public class WhisperFullParams extends Structure {
|
||||
* These are prepended to any existing text context from a previous call. */
|
||||
public String initial_prompt;
|
||||
|
||||
/** Prompt tokens. */
|
||||
/** Prompt tokens. (int*) */
|
||||
public Pointer prompt_tokens;
|
||||
|
||||
public void setPromptTokens(int[] tokens) {
|
||||
Memory mem = new Memory(tokens.length * 4L);
|
||||
mem.write(0, tokens, 0, tokens.length);
|
||||
prompt_tokens = mem;
|
||||
}
|
||||
|
||||
/** Number of prompt tokens. */
|
||||
public int prompt_n_tokens;
|
||||
|
||||
@ -90,15 +158,29 @@ public class WhisperFullParams extends Structure {
|
||||
public String language;
|
||||
|
||||
/** Flag to indicate whether to detect language automatically. */
|
||||
public boolean detect_language;
|
||||
public CBool detect_language;
|
||||
|
||||
/** Common decoding parameters. */
|
||||
/** Flag to indicate whether to detect language automatically. */
|
||||
public void detectLanguage(boolean enable) {
|
||||
detect_language = enable ? CBool.TRUE : CBool.FALSE;
|
||||
}
|
||||
|
||||
// Common decoding parameters.
|
||||
|
||||
/** Flag to suppress blank tokens. */
|
||||
public boolean suppress_blank;
|
||||
public CBool suppress_blank;
|
||||
|
||||
public void suppressBlanks(boolean enable) {
|
||||
suppress_blank = enable ? CBool.TRUE : CBool.FALSE;
|
||||
}
|
||||
|
||||
/** Flag to suppress non-speech tokens. */
|
||||
public boolean suppress_non_speech_tokens;
|
||||
public CBool suppress_non_speech_tokens;
|
||||
|
||||
/** Flag to suppress non-speech tokens. */
|
||||
public void suppressNonSpeechTokens(boolean enable) {
|
||||
suppress_non_speech_tokens = enable ? CBool.TRUE : CBool.FALSE;
|
||||
}
|
||||
|
||||
/** Initial decoding temperature. */
|
||||
public float temperature;
|
||||
@ -109,7 +191,7 @@ public class WhisperFullParams extends Structure {
|
||||
/** Length penalty. */
|
||||
public float length_penalty;
|
||||
|
||||
/** Fallback parameters. */
|
||||
// Fallback parameters.
|
||||
|
||||
/** Temperature increment. */
|
||||
public float temperature_inc;
|
||||
@ -123,31 +205,41 @@ public class WhisperFullParams extends Structure {
|
||||
/** No speech threshold. */
|
||||
public float no_speech_thold;
|
||||
|
||||
class GreedyParams extends Structure {
|
||||
/** https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264 */
|
||||
public int best_of;
|
||||
}
|
||||
|
||||
/** Greedy decoding parameters. */
|
||||
public GreedyParams greedy;
|
||||
|
||||
class BeamSearchParams extends Structure {
|
||||
/** ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265 */
|
||||
int beam_size;
|
||||
|
||||
/** ref: https://arxiv.org/pdf/2204.05424.pdf */
|
||||
float patience;
|
||||
}
|
||||
|
||||
/**
|
||||
* Beam search decoding parameters.
|
||||
*/
|
||||
public BeamSearchParams beam_search;
|
||||
|
||||
public void setBestOf(int bestOf) {
|
||||
if (greedy == null) {
|
||||
greedy = new GreedyParams();
|
||||
}
|
||||
greedy.best_of = bestOf;
|
||||
}
|
||||
|
||||
public void setBeamSize(int beamSize) {
|
||||
if (beam_search == null) {
|
||||
beam_search = new BeamSearchParams();
|
||||
}
|
||||
beam_search.beam_size = beamSize;
|
||||
}
|
||||
|
||||
public void setBeamSizeAndPatience(int beamSize, float patience) {
|
||||
if (beam_search == null) {
|
||||
beam_search = new BeamSearchParams();
|
||||
}
|
||||
beam_search.beam_size = beamSize;
|
||||
beam_search.patience = patience;
|
||||
}
|
||||
|
||||
/**
|
||||
* Callback for every newly generated text segment.
|
||||
* WhisperNewSegmentCallback
|
||||
*/
|
||||
public WhisperNewSegmentCallback new_segment_callback;
|
||||
public Pointer new_segment_callback;
|
||||
|
||||
/**
|
||||
* User data for the new_segment_callback.
|
||||
@ -156,8 +248,9 @@ public class WhisperFullParams extends Structure {
|
||||
|
||||
/**
|
||||
* Callback on each progress update.
|
||||
* WhisperProgressCallback
|
||||
*/
|
||||
public WhisperProgressCallback progress_callback;
|
||||
public Pointer progress_callback;
|
||||
|
||||
/**
|
||||
* User data for the progress_callback.
|
||||
@ -166,8 +259,9 @@ public class WhisperFullParams extends Structure {
|
||||
|
||||
/**
|
||||
* Callback each time before the encoder starts.
|
||||
* WhisperEncoderBeginCallback
|
||||
*/
|
||||
public WhisperEncoderBeginCallback encoder_begin_callback;
|
||||
public Pointer encoder_begin_callback;
|
||||
|
||||
/**
|
||||
* User data for the encoder_begin_callback.
|
||||
@ -176,12 +270,44 @@ public class WhisperFullParams extends Structure {
|
||||
|
||||
/**
|
||||
* Callback by each decoder to filter obtained logits.
|
||||
* WhisperLogitsFilterCallback
|
||||
*/
|
||||
public WhisperLogitsFilterCallback logits_filter_callback;
|
||||
public Pointer logits_filter_callback;
|
||||
|
||||
/**
|
||||
* User data for the logits_filter_callback.
|
||||
*/
|
||||
public Pointer logits_filter_callback_user_data;
|
||||
}
|
||||
|
||||
|
||||
public void setNewSegmentCallback(WhisperNewSegmentCallback callback) {
|
||||
new_segment_callback = CallbackReference.getFunctionPointer(callback);
|
||||
}
|
||||
|
||||
public void setProgressCallback(WhisperProgressCallback callback) {
|
||||
progress_callback = CallbackReference.getFunctionPointer(callback);
|
||||
}
|
||||
|
||||
public void setEncoderBeginCallbackeginCallbackCallback(WhisperEncoderBeginCallback callback) {
|
||||
encoder_begin_callback = CallbackReference.getFunctionPointer(callback);
|
||||
}
|
||||
|
||||
public void setLogitsFilterCallback(WhisperLogitsFilterCallback callback) {
|
||||
logits_filter_callback = CallbackReference.getFunctionPointer(callback);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected List<String> getFieldOrder() {
|
||||
return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate",
|
||||
"no_context", "single_segment",
|
||||
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
|
||||
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx",
|
||||
"initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
|
||||
"suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",
|
||||
"temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
|
||||
"new_segment_callback", "new_segment_callback_user_data",
|
||||
"progress_callback", "progress_callback_user_data",
|
||||
"encoder_begin_callback", "encoder_begin_callback_user_data",
|
||||
"logits_filter_callback", "logits_filter_callback_user_data");
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +0,0 @@
|
||||
package io.github.ggerganov.whispercpp.params;
|
||||
|
||||
import com.sun.jna.Structure;
|
||||
|
||||
public class WhisperJavaParams extends Structure {
|
||||
|
||||
}
|
@ -2,7 +2,8 @@ package io.github.ggerganov.whispercpp;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
import io.github.ggerganov.whispercpp.params.WhisperJavaParams;
|
||||
import io.github.ggerganov.whispercpp.params.CBool;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
@ -19,11 +20,11 @@ class WhisperCppTest {
|
||||
static void init() throws FileNotFoundException {
|
||||
// By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"
|
||||
// or you can provide the absolute path to the model file.
|
||||
String modelName = "base.en";
|
||||
String modelName = "../../models/ggml-tiny.en.bin";
|
||||
try {
|
||||
whisper.initContext(modelName);
|
||||
whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
// whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
// whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
// whisper.getJavaDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
modelInitialised = true;
|
||||
} catch (FileNotFoundException ex) {
|
||||
System.out.println("Model " + modelName + " not found");
|
||||
@ -31,11 +32,30 @@ class WhisperCppTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
void testGetDefaultJavaParams() {
|
||||
void testGetDefaultFullParams_BeamSearch() {
|
||||
// When
|
||||
whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
|
||||
// Then if it doesn't throw we've connected to whisper.cpp
|
||||
// Then
|
||||
assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH.ordinal(), params.strategy);
|
||||
assertNotEquals(0, params.n_threads);
|
||||
assertEquals(16384, params.n_max_text_ctx);
|
||||
assertFalse(params.translate);
|
||||
assertEquals(0.01f, params.thold_pt);
|
||||
assertEquals(2, params.beam_search.beam_size);
|
||||
assertEquals(-1.0f, params.beam_search.patience);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testGetDefaultFullParams_Greedy() {
|
||||
// When
|
||||
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
// Then
|
||||
assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY.ordinal(), params.strategy);
|
||||
assertNotEquals(0, params.n_threads);
|
||||
assertEquals(16384, params.n_max_text_ctx);
|
||||
assertEquals(2, params.greedy.best_of);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -52,6 +72,13 @@ class WhisperCppTest {
|
||||
byte[] b = new byte[audioInputStream.available()];
|
||||
float[] floats = new float[b.length / 2];
|
||||
|
||||
// WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));
|
||||
params.print_progress = CBool.FALSE;
|
||||
// params.initial_prompt = "and so my fellow Americans um, like";
|
||||
|
||||
|
||||
try {
|
||||
audioInputStream.read(b);
|
||||
|
||||
@ -61,13 +88,13 @@ class WhisperCppTest {
|
||||
}
|
||||
|
||||
// When
|
||||
String result = whisper.fullTranscribe(/*params,*/ floats);
|
||||
String result = whisper.fullTranscribe(params, floats);
|
||||
|
||||
// Then
|
||||
System.out.println(result);
|
||||
assertEquals("And so my fellow Americans, ask not what your country can do for you, " +
|
||||
System.err.println(result);
|
||||
assertEquals("And so my fellow Americans ask not what your country can do for you " +
|
||||
"ask what you can do for your country.",
|
||||
result);
|
||||
result.replace(",", ""));
|
||||
} finally {
|
||||
audioInputStream.close();
|
||||
}
|
||||
|
14
whisper.cpp
14
whisper.cpp
@ -2852,6 +2852,12 @@ void whisper_free(struct whisper_context * ctx) {
|
||||
}
|
||||
}
|
||||
|
||||
void whisper_free_params(struct whisper_full_params * params) {
|
||||
if (params) {
|
||||
delete params;
|
||||
}
|
||||
}
|
||||
|
||||
int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
|
||||
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
|
||||
fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
|
||||
@ -3285,6 +3291,14 @@ const char * whisper_print_system_info(void) {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy) {
|
||||
struct whisper_full_params params = whisper_full_default_params(strategy);
|
||||
|
||||
struct whisper_full_params* result = new whisper_full_params();
|
||||
*result = params;
|
||||
return result;
|
||||
}
|
||||
|
||||
struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
|
||||
struct whisper_full_params result = {
|
||||
/*.strategy =*/ strategy,
|
||||
|
@ -113,6 +113,7 @@ extern "C" {
|
||||
// Frees all allocated memory
|
||||
WHISPER_API void whisper_free (struct whisper_context * ctx);
|
||||
WHISPER_API void whisper_free_state(struct whisper_state * state);
|
||||
WHISPER_API void whisper_free_params(struct whisper_full_params * params);
|
||||
|
||||
// Convert RAW PCM audio to log mel spectrogram.
|
||||
// The resulting spectrogram is stored inside the default state of the provided whisper context.
|
||||
@ -409,6 +410,8 @@ extern "C" {
|
||||
void * logits_filter_callback_user_data;
|
||||
};
|
||||
|
||||
// NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_params()
|
||||
WHISPER_API struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy);
|
||||
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
|
||||
|
||||
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
||||
|
Loading…
Reference in New Issue
Block a user