mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-11-07 16:44:13 +01:00
ggml : support forward pass broadcasting in ggml_sub (ggml/914)
* ggml: support forward pass broadcasting in ggml_sub Signed-off-by: Salvatore Mesoraca <s.mesoraca16@gmail.com> * Use assert instead of GGML_ASSERT in ggml_compute_forward_sub_f32 The check is already performed in ggml_sub_impl Signed-off-by: Salvatore Mesoraca <s.mesoraca16@gmail.com> --------- Signed-off-by: Salvatore Mesoraca <s.mesoraca16@gmail.com>
This commit is contained in:
parent
9b1788483c
commit
993f0df419
@ -4661,11 +4661,13 @@ static struct ggml_tensor * ggml_sub_impl(
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
bool inplace) {
|
||||
GGML_ASSERT(ggml_are_same_shape(a, b));
|
||||
GGML_ASSERT(ggml_can_repeat(b, a));
|
||||
|
||||
bool is_node = false;
|
||||
|
||||
if (!inplace && (a->grad || b->grad)) {
|
||||
// TODO: support backward pass for broadcasting
|
||||
GGML_ASSERT(ggml_are_same_shape(a, b));
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
@ -10104,11 +10106,10 @@ static void ggml_compute_forward_sub_f32(
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
if (params->ith != 0) {
|
||||
return;
|
||||
}
|
||||
assert(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
|
||||
|
||||
assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int nr = ggml_nrows(src0);
|
||||
|
||||
@ -10117,40 +10118,55 @@ static void ggml_compute_forward_sub_f32(
|
||||
GGML_ASSERT( nb0 == sizeof(float));
|
||||
GGML_ASSERT(nb00 == sizeof(float));
|
||||
|
||||
if (nb10 == sizeof(float)) {
|
||||
for (int ir = 0; ir < nr; ++ir) {
|
||||
// src0, src1 and dst are same shape => same indices
|
||||
const int i3 = ir/(ne2*ne1);
|
||||
const int i2 = (ir - i3*ne2*ne1)/ne1;
|
||||
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
|
||||
// rows per thread
|
||||
const int dr = (nr + nth - 1)/nth;
|
||||
|
||||
// row range for this thread
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
if (nb10 == sizeof(float)) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
// src1 is broadcastable across src0 and dst in i1, i2, i3
|
||||
const int64_t i03 = ir/(ne02*ne01);
|
||||
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
||||
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
const int64_t i13 = i03 % ne13;
|
||||
const int64_t i12 = i02 % ne12;
|
||||
const int64_t i11 = i01 % ne11;
|
||||
const int64_t nr0 = ne00 / ne10;
|
||||
|
||||
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
|
||||
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
||||
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
|
||||
|
||||
for (int64_t r = 0; r < nr0; ++r) {
|
||||
#ifdef GGML_USE_ACCELERATE
|
||||
vDSP_vsub(
|
||||
(float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
|
||||
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
|
||||
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
|
||||
ne0);
|
||||
vDSP_vsub(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
|
||||
#else
|
||||
ggml_vec_sub_f32(ne0,
|
||||
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
|
||||
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
|
||||
(float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
|
||||
ggml_vec_sub_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
|
||||
#endif
|
||||
// }
|
||||
// }
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// src1 is not contiguous
|
||||
for (int ir = 0; ir < nr; ++ir) {
|
||||
// src0, src1 and dst are same shape => same indices
|
||||
const int i3 = ir/(ne2*ne1);
|
||||
const int i2 = (ir - i3*ne2*ne1)/ne1;
|
||||
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
// src1 is broadcastable across src0 and dst in i1, i2, i3
|
||||
const int64_t i03 = ir/(ne02*ne01);
|
||||
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
||||
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
|
||||
float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
|
||||
for (int i0 = 0; i0 < ne0; i0++) {
|
||||
float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
|
||||
const int64_t i13 = i03 % ne13;
|
||||
const int64_t i12 = i02 % ne12;
|
||||
const int64_t i11 = i01 % ne11;
|
||||
|
||||
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
|
||||
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
||||
|
||||
for (int64_t i0 = 0; i0 < ne0; ++i0) {
|
||||
const int64_t i10 = i0 % ne10;
|
||||
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
|
||||
|
||||
dst_ptr[i0] = src0_ptr[i0] - *src1_ptr;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user