mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-01-24 14:58:43 +01:00
ggml : implement ggml_compute_forward_dup_f16() special cases
This commit is contained in:
parent
32fbc8cd04
commit
a7047b2a28
90
ggml.c
90
ggml.c
@ -3178,23 +3178,97 @@ void ggml_compute_forward_dup_f16(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
//const int ne00 = src0->ne[0];
|
const int ne00 = src0->ne[0];
|
||||||
//const int ne01 = src0->ne[1];
|
const int ne01 = src0->ne[1];
|
||||||
//const int ne02 = src0->ne[2];
|
const int ne02 = src0->ne[2];
|
||||||
//const int ne03 = src0->ne[3];
|
const int ne03 = src0->ne[3];
|
||||||
|
|
||||||
//const size_t nb00 = src0->nb[0];
|
const size_t nb00 = src0->nb[0];
|
||||||
//const size_t nb01 = src0->nb[1];
|
const size_t nb01 = src0->nb[1];
|
||||||
//const size_t nb02 = src0->nb[2];
|
const size_t nb02 = src0->nb[2];
|
||||||
//const size_t nb03 = src0->nb[3];
|
const size_t nb03 = src0->nb[3];
|
||||||
|
|
||||||
if (ggml_is_contiguous(src0) && src0->type == dst->type) {
|
if (ggml_is_contiguous(src0) && src0->type == dst->type) {
|
||||||
memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
|
memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (src0->nb[0] == sizeof(ggml_fp16_t)) {
|
||||||
|
if (dst->type == GGML_TYPE_F16) {
|
||||||
|
int id = 0;
|
||||||
|
const size_t rs = ne00*nb00;
|
||||||
|
|
||||||
|
for (int i03 = 0; i03 < ne03; i03++) {
|
||||||
|
for (int i02 = 0; i02 < ne02; i02++) {
|
||||||
|
for (int i01 = 0; i01 < ne01; i01++) {
|
||||||
|
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
||||||
|
char * dst_ptr = (char *) dst->data + id*rs;
|
||||||
|
|
||||||
|
memcpy(dst_ptr, src0_ptr, rs);
|
||||||
|
|
||||||
|
id++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (dst->type == GGML_TYPE_F32) {
|
||||||
|
int id = 0;
|
||||||
|
float * dst_ptr = (float *) dst->data;
|
||||||
|
|
||||||
|
for (int i03 = 0; i03 < ne03; i03++) {
|
||||||
|
for (int i02 = 0; i02 < ne02; i02++) {
|
||||||
|
for (int i01 = 0; i01 < ne01; i01++) {
|
||||||
|
for (int i00 = 0; i00 < ne00; i00++) {
|
||||||
|
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||||
|
|
||||||
|
dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
|
||||||
|
id++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
GGML_ASSERT(false); // TODO: implement
|
GGML_ASSERT(false); // TODO: implement
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
//printf("%s: this is not optimal - fix me\n", __func__);
|
||||||
|
|
||||||
|
if (dst->type == GGML_TYPE_F32) {
|
||||||
|
int id = 0;
|
||||||
|
float * dst_ptr = (float *) dst->data;
|
||||||
|
|
||||||
|
for (int i03 = 0; i03 < ne03; i03++) {
|
||||||
|
for (int i02 = 0; i02 < ne02; i02++) {
|
||||||
|
for (int i01 = 0; i01 < ne01; i01++) {
|
||||||
|
for (int i00 = 0; i00 < ne00; i00++) {
|
||||||
|
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||||
|
|
||||||
|
dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
|
||||||
|
id++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (dst->type == GGML_TYPE_F16) {
|
||||||
|
int id = 0;
|
||||||
|
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
|
||||||
|
|
||||||
|
for (int i03 = 0; i03 < ne03; i03++) {
|
||||||
|
for (int i02 = 0; i02 < ne02; i02++) {
|
||||||
|
for (int i01 = 0; i01 < ne01; i01++) {
|
||||||
|
for (int i00 = 0; i00 < ne00; i00++) {
|
||||||
|
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||||
|
|
||||||
|
dst_ptr[id] = *src0_ptr;
|
||||||
|
id++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
GGML_ASSERT(false); // TODO: implement
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void ggml_compute_forward_dup_f32(
|
void ggml_compute_forward_dup_f32(
|
||||||
const struct ggml_compute_params * params,
|
const struct ggml_compute_params * params,
|
||||||
|
Loading…
Reference in New Issue
Block a user