diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index b8366b79f..105082cba 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -618,6 +618,8 @@ int main(int argc, char ** argv) {
                 fprintf(stderr, "%s: failed to process audio\n", argv[0]);
                 return 10;
             }
+
+            whisper_full_cluster_segments(ctx);
         }
 
         // output stuff
diff --git a/whisper.cpp b/whisper.cpp
index 04cbc36b2..4c208b9b4 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -603,6 +603,8 @@ struct whisper_context {
     // [EXPERIMENTAL] speed-up techniques
     int32_t exp_n_audio_ctx; // 0 - use default
 
+    std::vector<float> audio_embd;
+
     void use_buf(struct ggml_context * ctx, int i) {
 #if defined(WHISPER_USE_SCRATCH)
         size_t last_size = 0;
@@ -1707,18 +1709,34 @@ static bool whisper_encode(
     }
 
     // cur
-    //{
-    //    printf("ne0 = %d\n", cur->ne[0]);
-    //    printf("ne1 = %d\n", cur->ne[1]);
-    //    for (int i = 0; i < 10; ++i) {
-    //        printf("%8.4f ", ((float *)(cur->data))[i]);
-    //    }
-    //    printf("... ");
-    //    for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
-    //        printf("%8.4f ", ((float *)(cur->data))[i]);
-    //    }
-    //    printf("\n");
-    //}
+    {
+        //printf("ne0 = %d\n", cur->ne[0]);
+        //printf("ne1 = %d\n", cur->ne[1]);
+        //for (int i = 0; i < 10; ++i) {
+        //    printf("%8.4f ", ((float *)(cur->data))[i]);
+        //}
+        //printf("... ");
+        //for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
+        //    printf("%8.4f ", ((float *)(cur->data))[i]);
+        //}
+        //printf("\n");
+    }
+
+    {
+        const int i0 = std::min(mel_offset, mel_inp.n_len);
+        const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
+
+        printf("i0 = %d, i1 = %d, (i1 - i0) = %d, embd size = %d\n", i0, i1, i1 - i0, cur->ne[0]);
+
+        wctx.audio_embd.clear();
+        wctx.audio_embd.resize(cur->ne[0], 0.0f);
+        for (int j = 0; j < cur->ne[0]; ++j) {
+            for (int i = i0; i < i1; ++i) {
+                wctx.audio_embd[j] += ((float *)(cur->data))[(i - i0)*cur->ne[0] + j];
+            }
+            wctx.audio_embd[j] /= (i1 - i0);
+        }
+    }
 
     // pre-compute cross-attention memory
     {
@@ -4806,3 +4824,129 @@ static void whisper_exp_compute_token_level_timestamps(
     //    }
     //}
 }
+
+//
+// diarization stuff
+//
+
+void whisper_full_cluster_segments(struct whisper_context * ctx) {
+    const int n_segments = ctx->result_all.size();
+    printf("%s: clustering %d segments\n", __func__, n_segments);
+
+    const auto mel_len_save = ctx->mel.n_len;
+    printf("%s: mel_len_save = %d\n", __func__, mel_len_save);
+
+    std::vector<std::vector<float>> features(n_segments);
+
+    for (int i = 0; i < n_segments; ++i) {
+        const auto & segment_i = ctx->result_all[i];
+        printf("%s: segment %d: t0 = %d, t1 = %d, text = %s\n", __func__, i, (int) segment_i.t0, (int) segment_i.t1, segment_i.text.c_str());
+
+        ctx->mel.n_len = segment_i.t1;
+        whisper_encode(ctx, segment_i.t0, 4);
+
+        features[i] = ctx->audio_embd;
+    }
+
+    const int n_features = features[0].size();
+
+    // fuzzy c-means clustering
+    const int n_clusters = 4;
+
+    std::vector<std::vector<float>> centroids(n_clusters, std::vector<float>(n_features, 0.0));
+    std::vector<std::vector<float>> membership(n_segments, std::vector<float>(n_clusters, 0.0));
+
+    // initialize the centroids
+    for (int i = 0; i < n_clusters; ++i) {
+        for (int j = 0; j < n_features; ++j) {
+            centroids[i][j] = features[i][j];
+        }
+    }
+
+    // initialize the membership
+    for (int i = 0; i < n_segments; ++i) {
+        membership[i][i % n_clusters] = 1.0;
+    }
+
+    // iterate
+    for (int i = 0; i < 100; ++i) {
+        // update the centroids
+        for (int j = 0; j < n_clusters; ++j) {
+            for (int k = 0; k < n_features; ++k) {
+                centroids[j][k] = 0.0;
+            }
+        }
+
+        for (int j = 0; j < n_segments; ++j) {
+            for (int k = 0; k < n_clusters; ++k) {
+                for (int l = 0; l < n_features; ++l) {
+                    centroids[k][l] += membership[j][k]*features[j][l];
+                }
+            }
+        }
+
+        for (int j = 0; j < n_clusters; ++j) {
+            float sum = 0.0;
+            for (int k = 0; k < n_segments; ++k) {
+                sum += membership[k][j];
+            }
+
+            for (int k = 0; k < n_features; ++k) {
+                centroids[j][k] /= sum;
+            }
+        }
+
+        // update the membership
+        for (int j = 0; j < n_segments; ++j) {
+            for (int k = 0; k < n_clusters; ++k) {
+                float sum = 0.0;
+                for (int l = 0; l < n_clusters; ++l) {
+                    //sum += std::pow(whisper_distance(features[j], centroids[k])/whisper_distance(features[j], centroids[l]), 2.0/(2.0 - 1.0));
+
+                    // use the euclidean distance
+                    double d0 = 0.0;
+                    for (int m = 0; m < n_features; ++m) {
+                        d0 += std::pow(features[j][m] - centroids[k][m], 2.0);
+                    }
+                    d0 = std::sqrt(d0);
+
+                    double d1 = 0.0;
+                    for (int m = 0; m < n_features; ++m) {
+                        d1 += std::pow(features[j][m] - centroids[l][m], 2.0);
+                    }
+                    d1 = std::sqrt(d1);
+                    if (d1 == 0.0) {
+                        sum += 1.0;
+                    } else {
+                        sum += std::pow(d0/d1, 2.0/(2.0 - 1.0));
+                    }
+                }
+
+                membership[j][k] = 1.0/sum;
+            }
+        }
+
+        // print the membership
+        for (int i = 0; i < n_segments; ++i) {
+            printf("%s: membership %d: ", __func__, i);
+            for (int j = 0; j < n_clusters; ++j) {
+                printf("%f ", membership[i][j]);
+            }
+            printf(" '%s'\n", ctx->result_all[i].text.c_str());
+        }
+        printf("----------------\n");
+    }
+
+    // print the centroids
+    //for (int i = 0; i < n_clusters; ++i) {
+    //    printf("%s: centroid %d: ", __func__, i);
+    //    for (int j = 0; j < n_features; ++j) {
+    //        printf("%f ", centroids[i][j]);
+    //    }
+    //    printf("\n");
+    //}
+
+    // restore the mel length
+    ctx->mel.n_len = mel_len_save;
+}
+
diff --git a/whisper.h b/whisper.h
index 7eece797c..9e40e702c 100644
--- a/whisper.h
+++ b/whisper.h
@@ -372,6 +372,10 @@ extern "C" {
     WHISPER_API int whisper_bench_memcpy(int n_threads);
     WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads);
 
+    // Temporary experimental API
+
+    WHISPER_API void whisper_full_cluster_segments(struct whisper_context * ctx);
+
 #ifdef __cplusplus
 }
 #endif