diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 2a5463fd..6896a08b 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -6648,6 +6648,135 @@ static void ggml_compute_forward_repeat_back(
 
 // ggml_compute_forward_concat
 
+static void ggml_compute_forward_concat_any(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    const size_t len = ggml_type_size(src0->type);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int32_t dim = ggml_get_op_params_i32(dst, 0);
+
+    GGML_ASSERT(dim >= 0 && dim < 4);
+
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = src0->ne[dim];
+
+    const char * x;
+
+    // TODO: smarter multi-theading
+    for (int i3 = 0; i3 < ne3; i3++) {
+        for (int i2 = ith; i2 < ne2; i2 += nth) {
+            for (int i1 = 0; i1 < ne1; i1++) {
+                for (int i0 = 0; i0 < ne0; i0++) {
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        x = (const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03;
+                    } else {
+                        x = (const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13;
+                    }
+
+                    char * y = (char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3;
+
+                    memcpy(y, x, len);
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_concat_i8(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_type_size(src0->type) == sizeof(int8_t));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int32_t dim = ggml_get_op_params_i32(dst, 0);
+
+    GGML_ASSERT(dim >= 0 && dim < 4);
+
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = src0->ne[dim];
+
+    const int8_t * x;
+
+    // TODO: smarter multi-theading
+    for (int i3 = 0; i3 < ne3; i3++) {
+        for (int i2 = ith; i2 < ne2; i2 += nth) {
+            for (int i1 = 0; i1 < ne1; i1++) {
+                for (int i0 = 0; i0 < ne0; i0++) {
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        x = (const int8_t *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
+                    } else {
+                        x = (const int8_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
+                    }
+
+                    int8_t * y = (int8_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+
+                    *y = *x;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_concat_f16(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_type_size(src0->type) == sizeof(ggml_fp16_t));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int32_t dim = ggml_get_op_params_i32(dst, 0);
+
+    GGML_ASSERT(dim >= 0 && dim < 4);
+
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = src0->ne[dim];
+
+    const ggml_fp16_t * x;
+
+    // TODO: smarter multi-theading
+    for (int i3 = 0; i3 < ne3; i3++) {
+        for (int i2 = ith; i2 < ne2; i2 += nth) {
+            for (int i1 = 0; i1 < ne1; i1++) {
+                for (int i0 = 0; i0 < ne0; i0++) {
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        x = (const ggml_fp16_t *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
+                    } else {
+                        x = (const ggml_fp16_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
+                    }
+
+                    ggml_fp16_t * y = (ggml_fp16_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+
+                    *y = *x;
+                }
+            }
+        }
+    }
+}
+
 static void ggml_compute_forward_concat_f32(
     const struct ggml_compute_params * params,
     struct ggml_tensor * dst) {
@@ -6655,7 +6784,7 @@ static void ggml_compute_forward_concat_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float));
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -6698,6 +6827,16 @@ static void ggml_compute_forward_concat(
     const struct ggml_tensor * src0 = dst->src[0];
 
     switch (src0->type) {
+        case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
+        case GGML_TYPE_I16:
+            {
+                ggml_compute_forward_concat_f16(params, dst);
+            } break;
+        case GGML_TYPE_I8:
+            {
+                ggml_compute_forward_concat_i8(params, dst);
+            } break;
         case GGML_TYPE_F32:
         case GGML_TYPE_I32:
             {
@@ -6705,7 +6844,7 @@ static void ggml_compute_forward_concat(
             } break;
         default:
             {
-                GGML_ABORT("fatal error");
+                ggml_compute_forward_concat_any(params, dst);
             }
     }
 }
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 08424033..89409bb0 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -2332,6 +2332,7 @@ struct ggml_tensor * ggml_concat(
     struct ggml_tensor  * b,
     int                   dim) {
     GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
+    GGML_ASSERT(a->type == b->type);
 
     int64_t ne[GGML_MAX_DIMS];
     for (int d = 0; d < GGML_MAX_DIMS; ++d) {