From d352dbd163019d12b2234e60b7d161fe4192aee9 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Tue, 20 Feb 2024 14:17:09 -0500 Subject: [PATCH] ggml : fix conv_2d batch mode (ggml/737) Co-authored-by: bssrdf --- ggml.c | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ggml.c b/ggml.c index 4ee2c5e1..86e3bc2e 100644 --- a/ggml.c +++ b/ggml.c @@ -5629,7 +5629,9 @@ struct ggml_tensor * ggml_conv_2d( ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW] ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])); // [OC,IC, KH, KW] => [OC, IC * KH * KW] - result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], a->ne[3], im2col->ne[3]); // [N, OC, OH, OW] + result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], im2col->ne[3], a->ne[3]); // [OC, N, OH, OW] + result = ggml_cont(ctx, ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OH, OW] + return result; }