Feature/java bindings2 (#944)

* Java needs to call `whisper_full_default_params_by_ref()`, returning struct by val does not seem to work.
* added convenience methods to WhisperFullParams
* Remove unused WhisperJavaParams
This commit is contained in:
Nicholas Albion 2023-05-29 09:38:58 +10:00 committed by GitHub
parent 9b926844e3
commit d7c936b44a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 419 additions and 242 deletions

View File

@ -125,8 +125,10 @@ jobs:
include:
- arch: Win32
s2arc: x86
jnaPath: win32-x86
- arch: x64
s2arc: x64
jnaPath: win32-x86-64
- sdl2: ON
s2ver: 2.26.0
@ -159,6 +161,12 @@ jobs:
if: matrix.sdl2 == 'ON'
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
- name: Upload dll
uses: actions/upload-artifact@v3
with:
name: ${{ matrix.jnaPath }}_whisper.dll
path: build/bin/${{ matrix.build }}/whisper.dll
- name: Upload binaries
if: matrix.sdl2 == 'ON'
uses: actions/upload-artifact@v1
@ -363,3 +371,42 @@ jobs:
run: |
cd examples/whisper.android
./gradlew assembleRelease --no-daemon
# Builds the Java bindings on a Windows runner after the `windows` job has finished.
java:
needs: [ 'windows' ]
runs-on: windows-latest
steps:
- uses: actions/checkout@v1
- name: Install Java
uses: actions/setup-java@v1
with:
java-version: 17
# Fetch the x64 whisper.dll artifact uploaded by the `windows` job and place it in the
# win32-x86-64 resource directory of the Java build.
- name: Download Windows lib
uses: actions/download-artifact@v3
with:
name: win32-x86-64_whisper.dll
path: bindings/java/build/generated/resources/main/win32-x86-64
# Download the tiny.en model, then run the Gradle build from bindings/java.
- name: Build
run: |
models\download-ggml-model.cmd tiny.en
cd bindings/java
chmod +x ./gradlew
./gradlew build
# Publish the jar produced by the Gradle build as a workflow artifact.
- name: Upload jar
uses: actions/upload-artifact@v3
with:
name: whispercpp.jar
path: bindings/java/build/libs/whispercpp-*.jar
# Publishing is currently disabled; the step below is kept for reference.
# - name: Publish package
# if: ${{ github.ref == 'refs/heads/master' }}
# uses: gradle/gradle-build-action@v2
# with:
# arguments: publish
# env:
# MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }}
# MAVEN_PASSWORD: ${{ secrets.OSSRH_TOKEN }}

View File

@ -1,50 +0,0 @@
# Builds the `whisper_java` shared library (a small C++ wrapper around whisper.cpp)
# and writes it into the platform-specific resource directory that JNA loads from.
#
# target_link_directories() was introduced in CMake 3.13, so that is the minimum.
cmake_minimum_required(VERSION 3.13)

project(whisper_java VERSION 1.4.2 LANGUAGES CXX)

# Set the target name and source file/s
set(TARGET_NAME whisper_java)
set(SOURCES src/main/cpp/whisper_java.cpp)

# Build tree of the main whisper.cpp project. This assumes the bindings are
# configured in bindings/java/build, so three levels up is the repository root.
set(WHISPER_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/../../../build")

# Set the output directory for the DLL/shared library based on the platform as required by JNA.
# APPLE is tested before UNIX because both are true on macOS.
if(WIN32)
  set(OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/win32-x86-64")
elseif(APPLE)
  # JNA's resource prefix for macOS is `darwin`, not `macos`.
  set(OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/darwin")
elseif(UNIX)
  set(OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/linux-x86-64")
else()
  # Keep CMake's default location instead of continuing with an undefined variable.
  message(WARNING "whisper_java: unrecognised platform, leaving the library in ${CMAKE_CURRENT_BINARY_DIR}")
  set(OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}")
endif()

# Create the whisper_java library
add_library(${TARGET_NAME} SHARED ${SOURCES})

# include <whisper.h> from the repository root
target_include_directories(${TARGET_NAME} PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/../..")

# Link against the whisper library produced by the main build.
# $<CONFIG> resolves per configuration, so multi-config generators (where
# CMAKE_BUILD_TYPE is empty) find build/Release etc.; the bare build directory
# covers single-config builds without a build type.
target_link_directories(${TARGET_NAME} PRIVATE
  "${WHISPER_BUILD_DIR}/$<CONFIG>"
  "${WHISPER_BUILD_DIR}"
)
target_link_libraries(${TARGET_NAME} PRIVATE whisper)

# Warning flags depend on the compiler front end, not on the operating system.
if(MSVC)
  target_compile_options(${TARGET_NAME} PRIVATE /W4)
  target_compile_definitions(${TARGET_NAME} PRIVATE _CRT_SECURE_NO_WARNINGS)
else()
  target_compile_options(${TARGET_NAME} PRIVATE -Wall -Wextra)
endif()

target_compile_definitions(${TARGET_NAME} PRIVATE WHISPER_SHARED)

# Force CMake to save the libs to build/generated/resources/main/${os}-${arch} as required by JNA
set_target_properties(${TARGET_NAME} PROPERTIES
  RUNTIME_OUTPUT_DIRECTORY "${OUTPUT_DIR}"
  LIBRARY_OUTPUT_DIRECTORY "${OUTPUT_DIR}"
)

# Multi-config generators append a per-configuration sub-directory unless the
# per-configuration properties are set explicitly.
foreach(config IN LISTS CMAKE_CONFIGURATION_TYPES)
  string(TOUPPER "${config}" config_upper)
  set_target_properties(${TARGET_NAME} PROPERTIES
    RUNTIME_OUTPUT_DIRECTORY_${config_upper} "${OUTPUT_DIR}"
    LIBRARY_OUTPUT_DIRECTORY_${config_upper} "${OUTPUT_DIR}"
    ARCHIVE_OUTPUT_DIRECTORY_${config_upper} "${OUTPUT_DIR}"
  )
endforeach()

View File

@ -6,11 +6,7 @@ This package provides Java JNI bindings for whisper.cpp. They have been tested o
* Ubuntu on x86_64
* Windows on x86_64
The "low level" bindings are in `WhisperCppJnaLibrary` and `WhisperJavaJnaLibrary` which caches `whisper_full_params` and `whisper_context` in `whisper_java.cpp`.
There are a lot of classes in the `callbacks`, `ggml`, `model` and `params` directories but most of them have not been tested.
The most simple usage is as follows:
The "low level" bindings are in `WhisperCppJnaLibrary`. The most simple usage is as follows:
```java
import io.github.ggerganov.whispercpp.WhisperCpp;
@ -48,12 +44,6 @@ In order to build, you need to have the JDK 8 or higher installed. Run the tests
git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp/bindings/java
mkdir build
pushd build
cmake ..
cmake --build .
popd
./gradlew build
```

View File

@ -22,6 +22,12 @@ sourceSets {
}
}
// Copies the macOS shared library from the whisper.cpp build directory into the
// `darwin` resource directory. macOS shared libraries use the `.dylib` extension;
// the task name is kept unchanged because other tasks depend on it by name.
tasks.register('copyLibwhisperDynlib', Copy) {
    from '../../build'
    include 'libwhisper.dylib'
    into 'build/generated/resources/main/darwin'
}
tasks.register('copyLibwhisperSo', Copy) {
from '../../build'
include 'libwhisper.so'
@ -34,7 +40,9 @@ tasks.register('copyWhisperDll', Copy) {
into 'build/generated/resources/main/windows-x86-64'
}
tasks.build.dependsOn copyLibwhisperSo, copyWhisperDll
tasks.register('copyLibs') {
dependsOn copyLibwhisperDynlib, copyLibwhisperSo, copyWhisperDll
}
test {
systemProperty 'jna.library.path', project.file('build/generated/resources/main').absolutePath

View File

@ -1,33 +0,0 @@
#include <stdio.h>
#include "whisper_java.h"
struct whisper_full_params default_params;
struct whisper_context * whisper_ctx = nullptr;
struct void whisper_java_default_params(enum whisper_sampling_strategy strategy) {
default_params = whisper_full_default_params(strategy);
// struct whisper_java_params result = {};
// return result;
return;
}
/**
 * Loads the model at `path_model` into the file-level `whisper_ctx`.
 * If `default_params.n_threads` is still 0, the greedy defaults are stored as well.
 */
void whisper_java_init_from_file(const char * path_model) {
    whisper_ctx = whisper_init_from_file(path_model);

    const bool params_not_set = (default_params.n_threads == 0);
    if (params_not_set) {
        whisper_java_default_params(WHISPER_SAMPLING_GREEDY);
    }
}
/** Forwards to whisper_full() with the stored `default_params`, so the caller does not pass `whisper_full_params`. */
int whisper_java_full(
        struct whisper_context * ctx,
        const float * samples,
        int n_samples) {
    const int status = whisper_full(ctx, default_params, samples, n_samples);
    return status;
}
// Currently a no-op: `default_params` is a file-level object rather than a heap
// allocation, so there is nothing to release here.
void whisper_java_free() {
// free(default_params);
}

View File

@ -1,24 +0,0 @@
#define WHISPER_BUILD
#include <whisper.h>
#ifdef __cplusplus
extern "C" {
#endif
struct whisper_java_params {
};
WHISPER_API void whisper_java_default_params(enum whisper_sampling_strategy strategy);
WHISPER_API void whisper_java_init_from_file(const char * path_model);
WHISPER_API int whisper_java_full(
struct whisper_context * ctx,
// struct whisper_java_params params,
const float * samples,
int n_samples);
#ifdef __cplusplus
}
#endif

View File

@ -1,7 +1,8 @@
package io.github.ggerganov.whispercpp;
import com.sun.jna.Native;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.params.WhisperJavaParams;
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
import java.io.File;
@ -13,8 +14,9 @@ import java.io.IOException;
*/
public class WhisperCpp implements AutoCloseable {
private WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance;
private WhisperJavaJnaLibrary javaLib = WhisperJavaJnaLibrary.instance;
private Pointer ctx = null;
private Pointer greedyPointer = null;
private Pointer beamPointer = null;
public File modelDir() {
String modelDirPath = System.getenv("XDG_CACHE_HOME");
@ -27,9 +29,8 @@ public class WhisperCpp implements AutoCloseable {
/**
* @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")
* @return a Pointer to the WhisperContext
*/
void initContext(String modelPath) throws FileNotFoundException {
public void initContext(String modelPath) throws FileNotFoundException {
if (ctx != null) {
lib.whisper_free(ctx);
}
@ -42,7 +43,6 @@ public class WhisperCpp implements AutoCloseable {
modelPath = new File(modelDir(), modelPath).getAbsolutePath();
}
javaLib.whisper_java_init_from_file(modelPath);
ctx = lib.whisper_init_from_file(modelPath);
if (ctx == null) {
@ -51,22 +51,38 @@ public class WhisperCpp implements AutoCloseable {
}
/**
* Initialises `whisper_full_params` internally in whisper_java.cpp so JNA doesn't have to map everything.
* `whisper_java_init_from_file()` calls `whisper_java_default_params(WHISPER_SAMPLING_GREEDY)` for convenience.
* Provides default params which can be used with `whisper_full()` etc.
* Because this function allocates memory for the params, the caller must free it by
* calling `whisper_free_params()`; the memory is allocated with C++ `new`, so
* `Native.free(Pointer.nativeValue(pointer));` is not a matching deallocator.
*
* @param strategy - GREEDY
*/
public void getDefaultJavaParams(WhisperSamplingStrategy strategy) {
javaLib.whisper_java_default_params(strategy.ordinal());
// return lib.whisper_full_default_params(strategy.value)
}
public WhisperFullParams getFullDefaultParams(WhisperSamplingStrategy strategy) {
Pointer pointer;
// whisper_full_default_params was too hard to integrate with, so for now we use javaLib.whisper_java_default_params
// fun getDefaultParams(strategy: WhisperSamplingStrategy): WhisperFullParams {
// return lib.whisper_full_default_params(strategy.value)
// }
// whisper_full_default_params_by_ref allocates memory which we need to delete, so only create max 1 pointer for each strategy.
if (strategy == WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY) {
if (greedyPointer == null) {
greedyPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
}
pointer = greedyPointer;
} else {
if (beamPointer == null) {
beamPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
}
pointer = beamPointer;
}
WhisperFullParams params = new WhisperFullParams(pointer);
params.read();
return params;
}
@Override
public void close() {
freeContext();
freeParams();
System.out.println("Whisper closed");
}
@ -76,17 +92,28 @@ public class WhisperCpp implements AutoCloseable {
}
}
/**
 * Releases the cached default-params structs obtained from
 * `whisper_full_default_params_by_ref()`.
 * The native side allocates them with C++ `new`, so they are handed back to
 * `whisper_free_params()` (which uses `delete`) instead of `Native.free()`,
 * which is not a matching deallocator.
 */
private void freeParams() {
    if (greedyPointer != null) {
        lib.whisper_free_params(greedyPointer);
        greedyPointer = null;
    }
    if (beamPointer != null) {
        lib.whisper_free_params(beamPointer);
        beamPointer = null;
    }
}
/**
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.
* Not thread safe for same context
* Uses the specified decoding strategy to obtain the text.
*/
public String fullTranscribe(/*WhisperJavaParams whisperParams,*/ float[] audioData) throws IOException {
public String fullTranscribe(WhisperFullParams whisperParams, float[] audioData) throws IOException {
if (ctx == null) {
throw new IllegalStateException("Model not initialised");
}
if (javaLib.whisper_java_full(ctx, /*whisperParams,*/ audioData, audioData.length) != 0) {
if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) {
throw new IOException("Failed to process audio");
}

View File

@ -231,10 +231,21 @@ public interface WhisperCppJnaLibrary extends Library {
void whisper_print_timings(Pointer ctx);
void whisper_reset_timings(Pointer ctx);
// Note: Even if `whisper_full_params is stripped back to just 4 ints, JNA throws "Invalid memory access"
// when `whisper_full_default_params()` tries to return a struct.
// WhisperFullParams whisper_full_default_params(int strategy);
/**
* Provides default params which can be used with `whisper_full()` etc.
* Because this function allocates memory for the params, the caller must free it by
* calling `whisper_free_params()`; the memory is allocated with C++ `new`, so
* `Native.free(Pointer.nativeValue(pointer));` is not a matching deallocator.
*
* @param strategy - WhisperSamplingStrategy.value
*/
WhisperFullParams whisper_full_default_params(int strategy);
Pointer whisper_full_default_params_by_ref(int strategy);
void whisper_free_params(Pointer params);
/**
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text

View File

@ -1,23 +0,0 @@
package io.github.ggerganov.whispercpp;
import com.sun.jna.Library;
import com.sun.jna.Native;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.params.WhisperJavaParams;
/** JNA mapping of the functions exported by the native `whisper_java` helper library. */
interface WhisperJavaJnaLibrary extends Library {
/** Shared instance, bound to the native library named "whisper_java" when this interface is initialised. */
WhisperJavaJnaLibrary instance = Native.load("whisper_java", WhisperJavaJnaLibrary.class);
/** Stores the default params for the given sampling strategy on the native side. */
void whisper_java_default_params(int strategy);
void whisper_java_free();
/** Loads the model file at `modelPath` on the native side. */
void whisper_java_init_from_file(String modelPath);
/**
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.
* Not thread safe for same context
* Uses the specified decoding strategy to obtain the text.
*/
int whisper_java_full(Pointer ctx, /*WhisperJavaParams params, */float[] samples, int nSamples);
}

View File

@ -20,5 +20,5 @@ public interface WhisperEncoderBeginCallback extends Callback {
* @param user_data User data.
* @return True if the computation should proceed, false otherwise.
*/
boolean callback(WhisperContext ctx, WhisperState state, Pointer user_data);
boolean callback(Pointer ctx, Pointer state, Pointer user_data);
}

View File

@ -1,12 +1,9 @@
package io.github.ggerganov.whispercpp.callbacks;
import com.sun.jna.Callback;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.WhisperContext;
import io.github.ggerganov.whispercpp.model.WhisperState;
import io.github.ggerganov.whispercpp.model.WhisperTokenData;
import javax.security.auth.callback.Callback;
/**
* Callback to filter logits.
* Can be used to modify the logits before sampling.
@ -24,5 +21,5 @@ public interface WhisperLogitsFilterCallback extends Callback {
* @param logits The array of logits.
* @param user_data User data.
*/
void callback(WhisperContext ctx, WhisperState state, WhisperTokenData[] tokens, int n_tokens, float[] logits, Pointer user_data);
void callback(Pointer ctx, Pointer state, WhisperTokenData[] tokens, int n_tokens, float[] logits, Pointer user_data);
}

View File

@ -20,5 +20,5 @@ public interface WhisperNewSegmentCallback extends Callback {
* @param n_new The number of newly generated text segments.
* @param user_data User data.
*/
void callback(WhisperContext ctx, WhisperState state, int n_new, Pointer user_data);
void callback(Pointer ctx, Pointer state, int n_new, Pointer user_data);
}

View File

@ -1,11 +1,10 @@
package io.github.ggerganov.whispercpp.callbacks;
import com.sun.jna.Callback;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.WhisperContext;
import io.github.ggerganov.whispercpp.model.WhisperState;
import javax.security.auth.callback.Callback;
/**
* Callback for progress updates.
*/
@ -19,5 +18,5 @@ public interface WhisperProgressCallback extends Callback {
* @param progress The progress value.
* @param user_data User data.
*/
void callback(WhisperContext ctx, WhisperState state, int progress, Pointer user_data);
void callback(Pointer ctx, Pointer state, int progress, Pointer user_data);
}

View File

@ -0,0 +1,19 @@
package io.github.ggerganov.whispercpp.params;
import com.sun.jna.Structure;
import java.util.Arrays;
import java.util.List;
/** JNA structure holding the beam-search decoding parameters: `beam_size` followed by `patience`. */
public class BeamSearchParams extends Structure {
    /** ref: <a href="https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265">...</a> */
    public int beam_size;

    /** ref: <a href="https://arxiv.org/pdf/2204.05424.pdf">...</a> */
    public float patience;

    /** Names of the public fields, in declaration order. */
    @Override
    protected List<String> getFieldOrder() {
        String[] orderedNames = {"beam_size", "patience"};
        return Arrays.asList(orderedNames);
    }
}

View File

@ -0,0 +1,30 @@
package io.github.ggerganov.whispercpp.params;
import com.sun.jna.IntegerType;
import java.util.function.BooleanSupplier;
/**
 * One-byte integer type used to map boolean struct fields through JNA.
 * Implements {@link BooleanSupplier} so a value can be passed wherever a boolean supplier is accepted.
 */
public class CBool extends IntegerType implements BooleanSupplier {
    /** Size of the mapped value in bytes. */
    public static final int SIZE = 1;
    public static final CBool FALSE = new CBool(0);
    public static final CBool TRUE = new CBool(1);

    /** Creates a value equal to false. */
    public CBool() {
        this(0);
    }

    /** @param value raw value to store; zero means false. */
    public CBool(long value) {
        super(SIZE, value, true);
    }

    /**
     * @return true for any non-zero value, following C truthiness; comparing against 1 only
     *         would report other non-zero bytes as false.
     */
    @Override
    public boolean getAsBoolean() {
        return intValue() != 0;
    }

    @Override
    public String toString() {
        return getAsBoolean() ? "true" : "false";
    }
}

View File

@ -0,0 +1,16 @@
package io.github.ggerganov.whispercpp.params;
import com.sun.jna.Structure;
import java.util.Collections;
import java.util.List;
/** JNA structure holding the greedy decoding parameters; its only field is `best_of`. */
public class GreedyParams extends Structure {
    /** <a href="https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264">...</a> */
    public int best_of;

    /** Names of the public fields, in declaration order. */
    @Override
    protected List<String> getFieldOrder() {
        final String onlyField = "best_of";
        return Collections.singletonList(onlyField);
    }
}

View File

@ -1,13 +1,14 @@
package io.github.ggerganov.whispercpp.params;
import com.sun.jna.Callback;
import com.sun.jna.Pointer;
import com.sun.jna.Structure;
import com.sun.jna.*;
import io.github.ggerganov.whispercpp.callbacks.WhisperEncoderBeginCallback;
import io.github.ggerganov.whispercpp.callbacks.WhisperLogitsFilterCallback;
import io.github.ggerganov.whispercpp.callbacks.WhisperNewSegmentCallback;
import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback;
import java.util.Arrays;
import java.util.List;
/**
* Parameters for the whisper_full() function.
* If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
@ -15,62 +16,123 @@ import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback;
*/
public class WhisperFullParams extends Structure {
public WhisperFullParams(Pointer p) {
super(p);
// super(p, ALIGN_MSVC);
// super(p, ALIGN_GNUC);
}
/** Sampling strategy for whisper_full() function. */
public int strategy;
/** Number of threads. */
/** Number of threads. (default = 4) */
public int n_threads;
/** Maximum tokens to use from past text as a prompt for the decoder. */
/** Maximum tokens to use from past text as a prompt for the decoder. (default = 16384) */
public int n_max_text_ctx;
/** Start offset in milliseconds. */
/** Start offset in milliseconds. (default = 0) */
public int offset_ms;
/** Audio duration to process in milliseconds. */
/** Audio duration to process in milliseconds. (default = 0) */
public int duration_ms;
/** Translate flag. */
public boolean translate;
/** Translate flag. (default = false) */
public CBool translate;
/** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. */
public boolean no_context;
/** The complement of translateMode() */
public void transcribeMode() {
translate = CBool.FALSE;
}
/** Flag to force single segment output (useful for streaming). */
public boolean single_segment;
/** The complement of transcribeMode() */
public void translateMode() {
translate = CBool.TRUE;
}
/** Flag to print special tokens (e.g., &lt;SOT>, &lt;EOT>, &lt;BEG>, etc.). */
public boolean print_special;
/** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. (default = true) */
public CBool no_context;
/** Flag to print progress information. */
public boolean print_progress;
/** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. (default = true) */
public void enableContext(boolean enable) {
no_context = enable ? CBool.FALSE : CBool.TRUE;
}
/** Flag to print results from within whisper.cpp (avoid it, use callback instead). */
public boolean print_realtime;
/** Flag to force single segment output (useful for streaming). (default = false) */
public CBool single_segment;
/** Flag to print timestamps for each text segment when printing realtime. */
public boolean print_timestamps;
/** Flag to force single segment output (useful for streaming). (default = false) */
public void singleSegment(boolean single) {
single_segment = single ? CBool.TRUE : CBool.FALSE;
}
/** [EXPERIMENTAL] Flag to enable token-level timestamps. */
public boolean token_timestamps;
/** Flag to print special tokens (e.g., &lt;SOT>, &lt;EOT>, &lt;BEG>, etc.). (default = false) */
public CBool print_special;
/** [EXPERIMENTAL] Timestamp token probability threshold (~0.01). */
/** Flag to print special tokens (e.g., &lt;SOT>, &lt;EOT>, &lt;BEG>, etc.). (default = false) */
public void printSpecial(boolean enable) {
print_special = enable ? CBool.TRUE : CBool.FALSE;
}
/** Flag to print progress information. (default = true) */
public CBool print_progress;
/** Flag to print progress information. (default = true) */
public void printProgress(boolean enable) {
print_progress = enable ? CBool.TRUE : CBool.FALSE;
}
/** Flag to print results from within whisper.cpp (avoid it, use callback instead). (default = true) */
public CBool print_realtime;
/** Flag to print results from within whisper.cpp (avoid it, use callback instead). (default = true) */
public void printRealtime(boolean enable) {
print_realtime = enable ? CBool.TRUE : CBool.FALSE;
}
/** Flag to print timestamps for each text segment when printing realtime. (default = true) */
public CBool print_timestamps;
/** Flag to print timestamps for each text segment when printing realtime. (default = true) */
public void printTimestamps(boolean enable) {
print_timestamps = enable ? CBool.TRUE : CBool.FALSE;
}
/** [EXPERIMENTAL] Flag to enable token-level timestamps. (default = false) */
public CBool token_timestamps;
/** [EXPERIMENTAL] Flag to enable token-level timestamps. (default = false) */
public void tokenTimestamps(boolean enable) {
token_timestamps = enable ? CBool.TRUE : CBool.FALSE;
}
/** [EXPERIMENTAL] Timestamp token probability threshold (~0.01). (default = 0.01) */
public float thold_pt;
/** [EXPERIMENTAL] Timestamp token sum probability threshold (~0.01). */
public float thold_ptsum;
/** Maximum segment length in characters. */
/** Maximum segment length in characters. (default = 0) */
public int max_len;
/** Flag to split on word rather than on token (when used with max_len). */
public boolean split_on_word;
/** Flag to split on word rather than on token (when used with max_len). (default = false) */
public CBool split_on_word;
/** Maximum tokens per segment (0 = no limit). */
/** Flag to split on word rather than on token (when used with max_len). (default = false) */
public void splitOnWord(boolean enable) {
split_on_word = enable ? CBool.TRUE : CBool.FALSE;
}
/** Maximum tokens per segment (0, default = no limit) */
public int max_tokens;
/** Flag to speed up the audio by 2x using Phase Vocoder. */
public boolean speed_up;
/** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */
public CBool speed_up;
/** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */
public void speedUp(boolean enable) {
speed_up = enable ? CBool.TRUE : CBool.FALSE;
}
/** Overwrite the audio context size (0 = use default). */
public int audio_ctx;
@ -79,9 +141,15 @@ public class WhisperFullParams extends Structure {
* These are prepended to any existing text context from a previous call. */
public String initial_prompt;
/** Prompt tokens. */
/** Prompt tokens. (int*) */
public Pointer prompt_tokens;
/**
 * Copies `tokens` into native memory and stores both the buffer and its length.
 * A null or empty array clears the prompt tokens instead of failing, because a
 * zero-sized `Memory` cannot be allocated.
 *
 * @param tokens prompt token ids, 4 bytes each; may be null or empty
 */
public void setPromptTokens(int[] tokens) {
    if (tokens == null || tokens.length == 0) {
        prompt_tokens = null;
        prompt_n_tokens = 0;
        return;
    }
    Memory mem = new Memory(tokens.length * 4L);
    mem.write(0, tokens, 0, tokens.length);
    prompt_tokens = mem;
    // Keep the count in step with the buffer so the two fields always describe the same data.
    prompt_n_tokens = tokens.length;
}
/** Number of prompt tokens. */
public int prompt_n_tokens;
@ -90,15 +158,29 @@ public class WhisperFullParams extends Structure {
public String language;
/** Flag to indicate whether to detect language automatically. */
public boolean detect_language;
public CBool detect_language;
/** Common decoding parameters. */
/** Flag to indicate whether to detect language automatically. */
public void detectLanguage(boolean enable) {
detect_language = enable ? CBool.TRUE : CBool.FALSE;
}
// Common decoding parameters.
/** Flag to suppress blank tokens. */
public boolean suppress_blank;
public CBool suppress_blank;
public void suppressBlanks(boolean enable) {
suppress_blank = enable ? CBool.TRUE : CBool.FALSE;
}
/** Flag to suppress non-speech tokens. */
public boolean suppress_non_speech_tokens;
public CBool suppress_non_speech_tokens;
/** Flag to suppress non-speech tokens. */
public void suppressNonSpeechTokens(boolean enable) {
suppress_non_speech_tokens = enable ? CBool.TRUE : CBool.FALSE;
}
/** Initial decoding temperature. */
public float temperature;
@ -109,7 +191,7 @@ public class WhisperFullParams extends Structure {
/** Length penalty. */
public float length_penalty;
/** Fallback parameters. */
// Fallback parameters.
/** Temperature increment. */
public float temperature_inc;
@ -123,31 +205,41 @@ public class WhisperFullParams extends Structure {
/** No speech threshold. */
public float no_speech_thold;
class GreedyParams extends Structure {
/** https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264 */
public int best_of;
}
/** Greedy decoding parameters. */
public GreedyParams greedy;
class BeamSearchParams extends Structure {
/** ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265 */
int beam_size;
/** ref: https://arxiv.org/pdf/2204.05424.pdf */
float patience;
}
/**
* Beam search decoding parameters.
*/
public BeamSearchParams beam_search;
public void setBestOf(int bestOf) {
if (greedy == null) {
greedy = new GreedyParams();
}
greedy.best_of = bestOf;
}
public void setBeamSize(int beamSize) {
if (beam_search == null) {
beam_search = new BeamSearchParams();
}
beam_search.beam_size = beamSize;
}
public void setBeamSizeAndPatience(int beamSize, float patience) {
if (beam_search == null) {
beam_search = new BeamSearchParams();
}
beam_search.beam_size = beamSize;
beam_search.patience = patience;
}
/**
* Callback for every newly generated text segment.
* WhisperNewSegmentCallback
*/
public WhisperNewSegmentCallback new_segment_callback;
public Pointer new_segment_callback;
/**
* User data for the new_segment_callback.
@ -156,8 +248,9 @@ public class WhisperFullParams extends Structure {
/**
* Callback on each progress update.
* WhisperProgressCallback
*/
public WhisperProgressCallback progress_callback;
public Pointer progress_callback;
/**
* User data for the progress_callback.
@ -166,8 +259,9 @@ public class WhisperFullParams extends Structure {
/**
* Callback each time before the encoder starts.
* WhisperEncoderBeginCallback
*/
public WhisperEncoderBeginCallback encoder_begin_callback;
public Pointer encoder_begin_callback;
/**
* User data for the encoder_begin_callback.
@ -176,12 +270,44 @@ public class WhisperFullParams extends Structure {
/**
* Callback by each decoder to filter obtained logits.
* WhisperLogitsFilterCallback
*/
public WhisperLogitsFilterCallback logits_filter_callback;
public Pointer logits_filter_callback;
/**
* User data for the logits_filter_callback.
*/
public Pointer logits_filter_callback_user_data;
}
// Strong references to the installed callbacks. The struct fields only hold raw function
// pointers, so without these the callback objects could be garbage collected while native
// code still holds their pointers. They are private, so they are not part of the mapped
// struct layout (see getFieldOrder()).
private WhisperNewSegmentCallback newSegmentCallbackRef;
private WhisperProgressCallback progressCallbackRef;
private WhisperEncoderBeginCallback encoderBeginCallbackRef;
private WhisperLogitsFilterCallback logitsFilterCallbackRef;

/** Installs the callback invoked for every newly generated text segment. */
public void setNewSegmentCallback(WhisperNewSegmentCallback callback) {
    newSegmentCallbackRef = callback;
    new_segment_callback = CallbackReference.getFunctionPointer(callback);
}

/** Installs the callback invoked on each progress update. */
public void setProgressCallback(WhisperProgressCallback callback) {
    progressCallbackRef = callback;
    progress_callback = CallbackReference.getFunctionPointer(callback);
}

/** Installs the callback invoked each time before the encoder starts. */
public void setEncoderBeginCallback(WhisperEncoderBeginCallback callback) {
    encoderBeginCallbackRef = callback;
    encoder_begin_callback = CallbackReference.getFunctionPointer(callback);
}

/**
 * Misspelled original name, kept so existing callers keep compiling.
 *
 * @deprecated use {@link #setEncoderBeginCallback(WhisperEncoderBeginCallback)}
 */
@Deprecated
public void setEncoderBeginCallbackeginCallbackCallback(WhisperEncoderBeginCallback callback) {
    setEncoderBeginCallback(callback);
}

/** Installs the callback each decoder uses to filter the obtained logits. */
public void setLogitsFilterCallback(WhisperLogitsFilterCallback callback) {
    logitsFilterCallbackRef = callback;
    logits_filter_callback = CallbackReference.getFunctionPointer(callback);
}
@Override
protected List<String> getFieldOrder() {
return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate",
"no_context", "single_segment",
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx",
"initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
"suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",
"temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
"new_segment_callback", "new_segment_callback_user_data",
"progress_callback", "progress_callback_user_data",
"encoder_begin_callback", "encoder_begin_callback_user_data",
"logits_filter_callback", "logits_filter_callback_user_data");
}
}

View File

@ -1,7 +0,0 @@
package io.github.ggerganov.whispercpp.params;
import com.sun.jna.Structure;
/** Placeholder JNA structure; it declares no fields. */
public class WhisperJavaParams extends Structure {
}

View File

@ -2,7 +2,8 @@ package io.github.ggerganov.whispercpp;
import static org.junit.jupiter.api.Assertions.*;
import io.github.ggerganov.whispercpp.params.WhisperJavaParams;
import io.github.ggerganov.whispercpp.params.CBool;
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
@ -19,11 +20,11 @@ class WhisperCppTest {
static void init() throws FileNotFoundException {
// By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"
// or you can provide the absolute path to the model file.
String modelName = "base.en";
String modelName = "../../models/ggml-tiny.en.bin";
try {
whisper.initContext(modelName);
whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
// whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
// whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
// whisper.getJavaDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
modelInitialised = true;
} catch (FileNotFoundException ex) {
System.out.println("Model " + modelName + " not found");
@ -31,11 +32,30 @@ class WhisperCppTest {
}
@Test
void testGetDefaultJavaParams() {
void testGetDefaultFullParams_BeamSearch() {
// When
whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
// Then if it doesn't throw we've connected to whisper.cpp
// Then
assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH.ordinal(), params.strategy);
assertNotEquals(0, params.n_threads);
assertEquals(16384, params.n_max_text_ctx);
assertFalse(params.translate);
assertEquals(0.01f, params.thold_pt);
assertEquals(2, params.beam_search.beam_size);
assertEquals(-1.0f, params.beam_search.patience);
}
@Test
void testGetDefaultFullParams_Greedy() {
// When
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
// Then
assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY.ordinal(), params.strategy);
assertNotEquals(0, params.n_threads);
assertEquals(16384, params.n_max_text_ctx);
assertEquals(2, params.greedy.best_of);
}
@Test
@ -52,6 +72,13 @@ class WhisperCppTest {
byte[] b = new byte[audioInputStream.available()];
float[] floats = new float[b.length / 2];
// WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));
params.print_progress = CBool.FALSE;
// params.initial_prompt = "and so my fellow Americans um, like";
try {
audioInputStream.read(b);
@ -61,13 +88,13 @@ class WhisperCppTest {
}
// When
String result = whisper.fullTranscribe(/*params,*/ floats);
String result = whisper.fullTranscribe(params, floats);
// Then
System.out.println(result);
assertEquals("And so my fellow Americans, ask not what your country can do for you, " +
System.err.println(result);
assertEquals("And so my fellow Americans ask not what your country can do for you " +
"ask what you can do for your country.",
result);
result.replace(",", ""));
} finally {
audioInputStream.close();
}

View File

@ -2852,6 +2852,12 @@ void whisper_free(struct whisper_context * ctx) {
}
}
// Releases a parameter block with `delete`; a null `params` is ignored.
void whisper_free_params(struct whisper_full_params * params) {
    // `delete` on a null pointer is a no-op, so no explicit null check is needed
    delete params;
}
int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) {
fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
@ -3285,6 +3291,14 @@ const char * whisper_print_system_info(void) {
////////////////////////////////////////////////////////////////////////////
// Returns a heap-allocated copy of the default parameters for `strategy`.
// The copy is created with `new`; the caller owns it and is responsible for freeing it.
struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy) {
    return new whisper_full_params(whisper_full_default_params(strategy));
}
struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
struct whisper_full_params result = {
/*.strategy =*/ strategy,

View File

@ -113,6 +113,7 @@ extern "C" {
// Frees all allocated memory
WHISPER_API void whisper_free (struct whisper_context * ctx);
WHISPER_API void whisper_free_state(struct whisper_state * state);
WHISPER_API void whisper_free_params(struct whisper_full_params * params);
// Convert RAW PCM audio to log mel spectrogram.
// The resulting spectrogram is stored inside the default state of the provided whisper context.
@ -409,6 +410,8 @@ extern "C" {
void * logits_filter_callback_user_data;
};
// NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_params()
WHISPER_API struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy);
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text