mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-18 15:47:08 +02:00
ggml : add ggml_repeat_4d (llama/13824)
This commit is contained in:
parent
ad433403ce
commit
4dfb2c2215
@ -935,6 +935,15 @@ extern "C" {
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
|
||||
// repeat a to the specified shape
|
||||
GGML_API struct ggml_tensor * ggml_repeat_4d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int64_t ne0,
|
||||
int64_t ne1,
|
||||
int64_t ne2,
|
||||
int64_t ne3);
|
||||
|
||||
// sums repetitions in a into shape of b
|
||||
GGML_API struct ggml_tensor * ggml_repeat_back(
|
||||
struct ggml_context * ctx,
|
||||
|
@ -2319,6 +2319,26 @@ struct ggml_tensor * ggml_repeat(
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_repeat_4d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
|
||||
const bool can_repeat = ggml_is_empty(a) || (
|
||||
(ne0 % a->ne[0] == 0) &&
|
||||
(ne1 % a->ne[1] == 0) &&
|
||||
(ne2 % a->ne[2] == 0) &&
|
||||
(ne3 % a->ne[3] == 0)
|
||||
);
|
||||
GGML_ASSERT(can_repeat);
|
||||
|
||||
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
|
||||
|
||||
result->op = GGML_OP_REPEAT;
|
||||
result->src[0] = a;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_repeat_back
|
||||
|
||||
struct ggml_tensor * ggml_repeat_back(
|
||||
|
Loading…
x
Reference in New Issue
Block a user