From 6394c906af6540ea402ba1ad6837452b0b6aa7e7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 7 Jan 2023 19:20:18 +0200 Subject: [PATCH] ggml : fix running tasks with variable number of threads --- ggml.c | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/ggml.c b/ggml.c index e627164a..058241e7 100644 --- a/ggml.c +++ b/ggml.c @@ -4745,7 +4745,7 @@ static void ggml_compute_forward_mul_mat_f16_f32( // TODO: do not support transposed src1 assert(nb10/2 == sizeof(ggml_fp16_t)); - // parallelize by src0 rows using ggml_vec_dot_f32 + // parallelize by src0 rows using ggml_vec_dot_f16 // total rows in src0 const int nr = ne01*ne02*ne03; @@ -4773,7 +4773,7 @@ static void ggml_compute_forward_mul_mat_f16_f32( const int i3 = i03; ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); - ggml_fp16_t * src1_col = wdata + (i13*ne12*ne11 + i12*ne11 + 0)*ne00; + ggml_fp16_t * src1_col = wdata + ( 0 + i12*ne11 + i13*ne12*ne11)*ne00; float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); @@ -7142,7 +7142,9 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { } if (state->node) { - ggml_compute_forward(&state->params, state->node); + if (state->params.ith < state->params.nth) { + ggml_compute_forward(&state->params, state->node); + } state->node = NULL; } else { break; @@ -7236,9 +7238,15 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) } break; case GGML_OP_MUL_MAT: { - // TODO: use different scheduling for different matrix sizes node->n_tasks = n_threads; + // TODO: use different scheduling for different matrix sizes + //const int nr0 = ggml_nrows(node->src0); + //const int nr1 = ggml_nrows(node->src1); + + //node->n_tasks = MIN(n_threads, MAX(1, nr0/128)); + //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks); + size_t cur = 0; // TODO: better way to determine if the matrix is transposed @@ -7422,7 +7430,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) workers[j].params = (struct ggml_compute_params) { .type = GGML_TASK_COMPUTE, .ith = j + 1, - .nth = n_threads, + .nth = node->n_tasks, .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, .wdata = cgraph->work ? cgraph->work->data : NULL, }; @@ -7477,7 +7485,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) workers[j].params = (struct ggml_compute_params) { .type = GGML_TASK_FINALIZE, .ith = j + 1, - .nth = n_threads, + .nth = node->n_tasks, .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, .wdata = cgraph->work ? cgraph->work->data : NULL, };