diff --git a/ggml.c b/ggml.c
index d2a8c047..f5caeba0 100644
--- a/ggml.c
+++ b/ggml.c
@@ -394,6 +394,12 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y);
 static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y);
 
+ggml_collect_imatrix_t g_imatrix_collect = NULL;
+
+void ggml_set_imatrix_collection(ggml_collect_imatrix_t imatrix_collect) {
+    g_imatrix_collect = imatrix_collect;
+}
+
 static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
     [GGML_TYPE_I8] = {
         .type_name                = "i8",
@@ -9763,6 +9769,10 @@ static void ggml_compute_forward_mul_mat(
     const int ith = params->ith;
     const int nth = params->nth;
 
+    if (ith == 1 && g_imatrix_collect) {
+        g_imatrix_collect(src0, src1);
+    }
+
     const enum ggml_type type = src0->type;
 
     const bool src1_cont = ggml_is_contiguous(src1);
@@ -10066,6 +10076,10 @@ static void ggml_compute_forward_mul_mat_id(
 
         const struct ggml_tensor * src0_cur = dst->src[cur_a + 2];
 
+        if (ith == 1 && g_imatrix_collect) {
+            g_imatrix_collect(src0_cur, src1);
+        }
+
         const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
         const size_t row_size = ggml_row_size(vec_dot_type, ne10);
 
diff --git a/ggml.h b/ggml.h
index 93b42a27..4c2ff6c6 100644
--- a/ggml.h
+++ b/ggml.h
@@ -2067,6 +2067,12 @@ extern "C" {
 
     GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
 
+    //
+    // Importance matrix
+    //
+    typedef void(*ggml_collect_imatrix_t)(const struct ggml_tensor * src0, const struct ggml_tensor * src1);
+    GGML_API void ggml_set_imatrix_collection(ggml_collect_imatrix_t imatrix_collect);
+
     //
     // gguf
     //