Mirror of https://github.com/ggerganov/whisper.cpp.git, synced 2024-12-26 00:29:21 +01:00
ggml : add abort_callback for cpu backend (ggml/725)

* a way to use abort_callback with the cpu backend
* whisper update

parent aa8a75e287
commit f75e1197f1
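
For context, a minimal sketch (not part of this commit) of how the new CPU-backend abort callback could be used from application code. ggml_backend_cpu_init, ggml_backend_cpu_set_abort_callback, ggml_backend_graph_compute and ggml_backend_free are real ggml APIs; the g_should_abort flag, my_abort_callback and the main() scaffolding are illustrative only.

    #include <stdatomic.h>
    #include <stdbool.h>
    #include <stddef.h>

    #include "ggml-backend.h"

    static atomic_bool g_should_abort;   // illustrative flag, e.g. set from a UI or watchdog thread

    // Matches the new ggml_abort_callback typedef: returning true aborts the graph compute.
    static bool my_abort_callback(void * data) {
        (void) data;
        return atomic_load(&g_should_abort);
    }

    int main(void) {
        ggml_backend_t backend = ggml_backend_cpu_init();

        // New in this commit: register the callback on the CPU backend;
        // the last argument is passed back to the callback as `data`.
        ggml_backend_cpu_set_abort_callback(backend, my_abort_callback, NULL);

        // ... build a ggml graph and run it with ggml_backend_graph_compute(backend, graph);
        // setting g_should_abort to true from another thread stops the computation early.

        ggml_backend_free(backend);
        return 0;
    }
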
ggml-backend.c

@@ -653,6 +653,9 @@ struct ggml_backend_cpu_context {
     int n_threads;
     void * work_data;
     size_t work_size;
+
+    ggml_abort_callback abort_callback;
+    void * abort_callback_data;
 };
 
 GGML_CALL static const char * ggml_backend_cpu_name(ggml_backend_t backend) {
@@ -691,6 +694,9 @@ GGML_CALL static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(gg
         cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size);
     }
 
+    cpu_plan->cplan.abort_callback = cpu_ctx->abort_callback;
+    cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data;
+
     return cpu_plan;
 }
 
@@ -721,9 +727,11 @@ GGML_CALL static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, str
         cpu_ctx->work_data = realloc(cpu_ctx->work_data, cplan.work_size);
         cpu_ctx->work_size = cplan.work_size;
     }
 
     cplan.work_data = cpu_ctx->work_data;
+    cplan.abort_callback = cpu_ctx->abort_callback;
+    cplan.abort_callback_data = cpu_ctx->abort_callback_data;
 
     ggml_graph_compute(cgraph, &cplan);
     return true;
 }
@@ -759,9 +767,11 @@ static struct ggml_backend_i cpu_backend_i = {
 ggml_backend_t ggml_backend_cpu_init(void) {
     struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context));
 
     ctx->n_threads = GGML_DEFAULT_N_THREADS;
     ctx->work_data = NULL;
     ctx->work_size = 0;
+    ctx->abort_callback = NULL;
+    ctx->abort_callback_data = NULL;
 
     ggml_backend_t cpu_backend = malloc(sizeof(struct ggml_backend));
 
@@ -783,6 +793,14 @@ void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
     ctx->n_threads = n_threads;
 }
 
+void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data) {
+    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
+
+    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+    ctx->abort_callback = abort_callback;
+    ctx->abort_callback_data = abort_callback_data;
+}
+
 GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
     return ggml_backend_buffer_init(ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size);
 }
ggml-backend.h

@@ -83,8 +83,9 @@ extern "C" {
 
     GGML_API ggml_backend_t ggml_backend_cpu_init(void);
 
     GGML_API GGML_CALL bool ggml_backend_is_cpu (ggml_backend_t backend);
-    GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads);
+    GGML_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads);
+    GGML_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
 
     // Create a backend buffer from an existing pointer
     GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
ggml.c

@@ -16560,7 +16560,7 @@ struct ggml_compute_state_shared {
     atomic_int node_n;    // active graph node
     atomic_int node_task; // active graph node task phase
 
-    bool (*abort_callback)(void * data); // abort ggml_graph_compute when true
+    ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
     void * abort_callback_data;
 };
 
ggml.h

@@ -567,6 +567,11 @@ extern "C" {
 
     static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
 
+    // Abort callback
+    // If not NULL, called before ggml computation
+    // If it returns true, the computation is aborted
+    typedef bool (*ggml_abort_callback)(void * data);
+
     // the compute plan that needs to be prepared for ggml_graph_compute()
     // since https://github.com/ggerganov/ggml/issues/287
     struct ggml_cplan {
@@ -576,8 +581,8 @@ extern "C" {
         int n_threads;
 
         // abort ggml_graph_compute when true
-        bool (*abort_callback)(void * data);
+        ggml_abort_callback abort_callback;
         void * abort_callback_data;
     };
 
     enum ggml_cgraph_eval_order {
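
The ggml.h hunks above add the ggml_abort_callback typedef and switch ggml_cplan over to it. A rough sketch of using the callback at the cplan level, assuming the contract documented above; the deadline struct and both helper functions are made-up names, not part of ggml:

    #include <stdbool.h>
    #include <stdint.h>
    #include <stdlib.h>

    #include "ggml.h"

    struct deadline { int64_t t_end_us; };   // illustrative user data

    // Returning true aborts ggml_graph_compute, per the new typedef's contract.
    static bool abort_after_deadline(void * data) {
        const struct deadline * d = (const struct deadline *) data;
        return ggml_time_us() > d->t_end_us;
    }

    static void compute_with_deadline(struct ggml_cgraph * graph, int n_threads, struct deadline * d) {
        struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);

        // the plan needs a work buffer of the size it requests
        uint8_t * work = NULL;
        if (plan.work_size > 0) {
            work = malloc(plan.work_size);
            plan.work_data = work;
        }

        plan.abort_callback      = abort_after_deadline;
        plan.abort_callback_data = d;

        ggml_graph_compute(graph, &plan);   // stops early once the callback returns true

        free(work);
    }
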
whisper.cpp

@@ -156,11 +156,11 @@ static bool ggml_graph_compute_helper(
         struct ggml_cgraph * graph,
         std::vector<uint8_t> & buf,
         int n_threads,
-        whisper_abort_callback abort_callback,
+        ggml_abort_callback abort_callback,
         void * abort_callback_data) {
     struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
 
     plan.abort_callback = abort_callback;
     plan.abort_callback_data = abort_callback_data;
 
     if (plan.work_size > 0) {
@@ -2130,7 +2130,7 @@ static bool whisper_encode_internal(
         whisper_state & wstate,
         const int mel_offset,
         const int n_threads,
-        whisper_abort_callback abort_callback,
+        ggml_abort_callback abort_callback,
         void * abort_callback_data) {
     const int64_t t_start_us = ggml_time_us();
 
@@ -2561,7 +2561,7 @@ static bool whisper_decode_internal(
         whisper_state & wstate,
         const whisper_batch & batch,
         const int n_threads,
-        whisper_abort_callback abort_callback,
+        ggml_abort_callback abort_callback,
         void * abort_callback_data) {
     const int64_t t_start_us = ggml_time_us();
 
whisper.h

@@ -412,11 +412,6 @@ extern "C" {
     // If it returns false, the computation is aborted
     typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);
 
-    // Abort callback
-    // If not NULL, called before ggml computation
-    // If it returns true, the computation is aborted
-    typedef bool (*whisper_abort_callback)(void * user_data);
-
     // Logits filter callback
     // Can be used to modify the logits before sampling
     // If not NULL, called after applying temperature to logits
@@ -513,7 +508,7 @@ extern "C" {
         void * encoder_begin_callback_user_data;
 
         // called each time before ggml computation starts
-        whisper_abort_callback abort_callback;
+        ggml_abort_callback abort_callback;
         void * abort_callback_user_data;
 
         // called by each decoder to filter obtained logits
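
For completeness, a hedged sketch of how an application could populate the (now ggml_abort_callback-typed) field in whisper_full_params shown above; abort_requested, make_params and the abort_flag pointer are illustrative names, not whisper API:

    #include <stdbool.h>
    #include <stddef.h>

    #include "whisper.h"

    // Illustrative: abort whisper_full() when an external flag is set.
    static bool abort_requested(void * user_data) {
        return *(volatile bool *) user_data;
    }

    static struct whisper_full_params make_params(volatile bool * abort_flag) {
        struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

        // Same behavior as before this change; only the field's type moved to ggml_abort_callback.
        wparams.abort_callback           = abort_requested;
        wparams.abort_callback_user_data = (void *) abort_flag;

        return wparams;
    }
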