diff --git a/ggml.c b/ggml.c
index d67612c36..54094f04f 100644
--- a/ggml.c
+++ b/ggml.c
@@ -8517,6 +8517,193 @@ enum ggml_opt_result ggml_opt(
 
 ////////////////////////////////////////////////////////////////////////////////
 
+void ggml_svd_reduce_dims(
+        int ne0,
+        int ne1,
+        float * a,
+        int nd) {
+    int n = ne1;
+    int m = ne0;
+
+    float * A = a;
+    float * A0 = (float *) malloc(n * m * sizeof(float));
+
+    // average vector
+    float * M = (float *) malloc(m * sizeof(float));
+
+    {
+        for (int j = 0; j < m; ++j) {
+            M[j] = 0.0f;
+        }
+        for (int i = 0; i < n; ++i) {
+            for (int j = 0; j < m; ++j) {
+                M[j] += A[i * m + j];
+            }
+        }
+        for (int j = 0; j < m; ++j) {
+            M[j] /= (float) n;
+        }
+    }
+
+    // subtract average vector
+    for (int i = 0; i < n; ++i) {
+        for (int j = 0; j < m; ++j) {
+            A[i * m + j] -= M[j];
+        }
+    }
+
+    memcpy(A0, A, n * m * sizeof(float));
+
+    // print A
+    //printf("A:\n");
+    //for (int i = 0; i < n; ++i) {
+    //    printf("col %d : ", i);
+    //    for (int j = 0; j < m; ++j) {
+    //        printf("%9.5f ", A[i * m + j]);
+    //    }
+    //    printf("\n");
+    //}
+    //printf("\n");
+
+    // SVD
+    // A = U * S * V^T
+
+    float * U = (float *) malloc(n * m * sizeof(float));
+    float * S = (float *) malloc(n * sizeof(float));
+    float * V = (float *) malloc(n * n * sizeof(float));
+
+    int lda = m;
+    int ldu = m;
+    int ldvt = n;
+
+    float work_size;
+    int lwork = -1;
+    int info = 0;
+
+    sgesvd_("S", "S", &m, &n, A, &lda, S, U, &ldu, V, &ldvt, &work_size, &lwork, &info);
+
+    lwork = (int) work_size;
+
+    //printf("work_size = %f, info = %d, lwork = %d\n", work_size, info, lwork);
+
+    float * work = (float *) malloc(lwork * sizeof(float));
+
+    sgesvd_("S", "S", &m, &n, A, &lda, S, U, &ldu, V, &ldvt, work, &lwork, &info);
+
+    free(work);
+
+    // print U
+    //printf("U:\n");
+    //for (int i = 0; i < n; ++i) {
+    //    printf("col %d : ", i);
+    //    for (int j = 0; j < m; ++j) {
+    //        printf("%9.5f ", U[i * m + j]);
+    //    }
+    //    printf("\n");
+    //}
+    //printf("\n");
+
+    // normalize S
+    {
+        double sum = 0.0;
+        for (int i = 0; i < n; ++i) {
+            sum += S[i];
+        }
+        sum *= sqrt((double) m);
+        for (int i = 0; i < n; ++i) {
+            S[i] /= sum;
+        }
+    }
+
+    // print S
+    //printf("S:\n");
+    //for (int i = 0; i < n; ++i) {
+    //    printf("- %d = %9.5f\n", i, S[i]);
+    //}
+    //printf("\n");
+
+    // print V
+    //printf("V:\n");
+    //for (int i = 0; i < n; ++i) {
+    //    printf("col %d : ", i);
+    //    for (int j = 0; j < n; ++j) {
+    //        printf("%9.5f ", V[i * n + j]);
+    //    }
+    //    printf("\n");
+    //}
+    //printf("\n");
+
+    // print A
+    //printf("A:\n");
+    //for (int i = 0; i < n; ++i) {
+    //    printf("col %d : ", i);
+    //    for (int j = 0; j < m; ++j) {
+    //        printf("%9.5f ", A[i * m + j]);
+    //    }
+    //    printf("\n");
+    //}
+    //printf("\n");
+
+    // compute singular vectors in U
+    for (int i = 0; i < n; ++i) {
+        for (int j = 0; j < m; ++j) {
+            U[i * m + j] *= S[i];
+        }
+    }
+
+    // normalize U
+    for (int i = 0; i < n; ++i) {
+        double sum = 0.0;
+        for (int j = 0; j < m; ++j) {
+            sum += U[i * m + j] * U[i * m + j];
+        }
+        sum = sqrt(sum);
+        for (int j = 0; j < m; ++j) {
+            U[i * m + j] /= sum*sqrt((double) m);
+        }
+    }
+
+    // print U
+    //printf("U:\n");
+    //for (int i = 0; i < n; ++i) {
+    //    printf("col %d : ", i);
+    //    for (int j = 0; j < m; ++j) {
+    //        printf("%9.5f ", U[i * m + j]);
+    //    }
+    //    printf("\n");
+    //}
+    //printf("\n");
+
+
+    // project A0 onto U
+    for (int i = 0; i < n; ++i) {
+        for (int j = 0; j < n; ++j) {
+            A[i * nd + j] = 0.0f;
+            for (int k = 0; k < m; ++k) {
+                A[i * nd + j] += A0[i * m + k] * U[j * m + k];
+            }
+        }
+    }
+
+    // print A
+    //printf("A:\n");
+    //for (int i = 0; i < n; ++i) {
+    //    printf("col %d : ", i);
+    //    for (int j = 0; j < n; ++j) {
+    //        printf("%9.5f ", A[i * n + j]);
+    //    }
+    //    printf("\n");
+    //}
+    //printf("\n");
+
+    free(U);
+    free(S);
+    free(V);
+    free(A0);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
 int ggml_cpu_has_avx(void) {
 #if defined(__AVX__)
     return 1;
diff --git a/ggml.h b/ggml.h
index 18f317bec..e63b28612 100644
--- a/ggml.h
+++ b/ggml.h
@@ -726,6 +726,16 @@ enum ggml_opt_result ggml_opt(
         struct ggml_opt_params params,
         struct ggml_tensor * f);
 
+//
+// Temp stuff
+//
+
+void ggml_svd_reduce_dims(
+        int ne0,
+        int ne1,
+        float * a,
+        int nd);
+
 //
 // system info
 //
diff --git a/whisper.cpp b/whisper.cpp
index 4c208b9b4..0b91a151c 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -603,8 +603,6 @@ struct whisper_context {
     // [EXPERIMENTAL] speed-up techniques
     int32_t exp_n_audio_ctx; // 0 - use default
 
-    std::vector<float> audio_embd;
-
     void use_buf(struct ggml_context * ctx, int i) {
 #if defined(WHISPER_USE_SCRATCH)
         size_t last_size = 0;
@@ -1360,7 +1358,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 static bool whisper_encode(
         whisper_context & wctx,
               const int   mel_offset,
-              const int   n_threads) {
+              const int   n_threads,
+              bool repeat = false) {
     const int64_t t_start_us = ggml_time_us();
 
     const auto & model   = wctx.model;
@@ -1392,9 +1391,24 @@ static bool whisper_encode(
         const int i0 = std::min(mel_offset, mel_inp.n_len);
         const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
 
-        for (int j = 0; j < mel_inp.n_mel; ++j) {
-            for (int i = i0; i < i1; ++i) {
-                dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
+        if (repeat == false) {
+            for (int j = 0; j < mel_inp.n_mel; ++j) {
+                for (int i = i0; i < i1; ++i) {
+                    dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
+                }
+            }
+        } else {
+            for (int j = 0; j < mel_inp.n_mel; ++j) {
+                int k = 0;
+                while (k < 2*n_ctx) {
+                    for (int i = i0; i < i1; ++i) {
+                        dst[j*2*n_ctx + k] = mel_inp.data[j*mel_inp.n_len + i];
+                        k++;
+                        if (k >= 2*n_ctx) {
+                            break;
+                        }
+                    }
+                }
             }
         }
     }
@@ -1722,22 +1736,6 @@ static bool whisper_encode(
         //printf("\n");
     }
 
-    {
-        const int i0 = std::min(mel_offset, mel_inp.n_len);
-        const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
-
-        printf("i0 = %d, i1 = %d, (i1 - i0) = %d, embd size = %d\n", i0, i1, i1 - i0, cur->ne[0]);
-
-        wctx.audio_embd.clear();
-        wctx.audio_embd.resize(cur->ne[0], 0.0f);
-        for (int j = 0; j < cur->ne[0]; ++j) {
-            for (int i = i0; i < i1; ++i) {
-                wctx.audio_embd[j] += ((float *)(cur->data))[(i - i0)*cur->ne[0] + j];
-            }
-            wctx.audio_embd[j] /= (i1 - i0);
-        }
-    }
-
     // pre-compute cross-attention memory
     {
         struct ggml_cgraph gf = {};
@@ -4836,117 +4834,151 @@ void whisper_full_cluster_segments(struct whisper_context * ctx) {
     const auto mel_len_save = ctx->mel.n_len;
     printf("%s: mel_len_save = %d\n", __func__, mel_len_save);
 
-    std::vector<std::vector<float>> features(n_segments);
+    const int n_ctx   = ctx->model.hparams.n_audio_ctx;
+    const int n_state = ctx->model.hparams.n_audio_state;
+    const int n_layer = ctx->model.hparams.n_audio_layer;
 
-    for (int i = 0; i < n_segments; ++i) {
-        const auto & segment_i = ctx->result_all[i];
-        printf("%s: segment %d: t0 = %d, t1 = %d, text = %s\n", __func__, i, (int) segment_i.t0, (int) segment_i.t1, segment_i.text.c_str());
+    for (int il = 0; il < n_layer; ++il) {
+        std::vector<float> embd(n_segments*n_ctx*n_state);
 
-        ctx->mel.n_len = segment_i.t1;
-        whisper_encode(ctx, segment_i.t0, 4);
-
-        features[i] = ctx->audio_embd;
-    }
-
-    const int n_features = features[0].size();
-
-    // fuzzy c-means clustering
-    const int n_clusters = 4;
-
-    std::vector<std::vector<float>> centroids(n_clusters, std::vector<float>(n_features, 0.0));
-    std::vector<std::vector<float>> membership(n_segments, std::vector<float>(n_clusters, 0.0));
-
-    // initialize the centroids
-    for (int i = 0; i < n_clusters; ++i) {
-        for (int j = 0; j < n_features; ++j) {
-            centroids[i][j] = features[i][j];
-        }
-    }
-
-    // initialize the membership
-    for (int i = 0; i < n_segments; ++i) {
-        membership[i][i % n_clusters] = 1.0;
-    }
-
-    // iterate
-    for (int i = 0; i < 100; ++i) {
-        // update the centroids
-        for (int j = 0; j < n_clusters; ++j) {
-            for (int k = 0; k < n_features; ++k) {
-                centroids[j][k] = 0.0;
-            }
-        }
-
-        for (int j = 0; j < n_segments; ++j) {
-            for (int k = 0; k < n_clusters; ++k) {
-                for (int l = 0; l < n_features; ++l) {
-                    centroids[k][l] += membership[j][k]*features[j][l];
-                }
-            }
-        }
-
-        for (int j = 0; j < n_clusters; ++j) {
-            float sum = 0.0;
-            for (int k = 0; k < n_segments; ++k) {
-                sum += membership[k][j];
-            }
-
-            for (int k = 0; k < n_features; ++k) {
-                centroids[j][k] /= sum;
-            }
-        }
-
-        // update the membership
-        for (int j = 0; j < n_segments; ++j) {
-            for (int k = 0; k < n_clusters; ++k) {
-                float sum = 0.0;
-                for (int l = 0; l < n_clusters; ++l) {
-                    //sum += std::pow(whisper_distance(features[j], centroids[k])/whisper_distance(features[j], centroids[l]), 2.0/(2.0 - 1.0));
-
-                    // use the euclidean distance
-                    double d0 = 0.0;
-                    for (int m = 0; m < n_features; ++m) {
-                        d0 += std::pow(features[j][m] - centroids[k][m], 2.0);
-                    }
-                    d0 = std::sqrt(d0);
-
-                    double d1 = 0.0;
-                    for (int m = 0; m < n_features; ++m) {
-                        d1 += std::pow(features[j][m] - centroids[l][m], 2.0);
-                    }
-                    d1 = std::sqrt(d1);
-                    if (d1 == 0.0) {
-                        sum += 1.0;
-                    } else {
-                        sum += std::pow(d0/d1, 2.0/(2.0 - 1.0));
-                    }
-                }
-
-                membership[j][k] = 1.0/sum;
-            }
-        }
-
-        // print the membership
         for (int i = 0; i < n_segments; ++i) {
-            printf("%s: membership %d: ", __func__, i);
-            for (int j = 0; j < n_clusters; ++j) {
-                printf("%f ", membership[i][j]);
-            }
-            printf(" '%s'\n", ctx->result_all[i].text.c_str());
-        }
-        printf("----------------\n");
-    }
+            const auto & segment_i = ctx->result_all[i];
+            printf("%s: layer %2d, segment %3d: t0 = %7d, t1 = %7d, text = %s\n", __func__, il, i, (int) segment_i.t0, (int) segment_i.t1, segment_i.text.c_str());
 
-    // print the centroids
-    //for (int i = 0; i < n_clusters; ++i) {
-    //    printf("%s: centroid %d: ", __func__, i);
-    //    for (int j = 0; j < n_features; ++j) {
-    //        printf("%f ", centroids[i][j]);
-    //    }
-    //    printf("\n");
-    //}
+            ctx->mel.n_len = segment_i.t1;
+            whisper_encode(*ctx, segment_i.t0, 7, true);
+
+            const size_t offs = ggml_element_size(ctx->kv_cross.k)*(il*n_ctx*n_state);
+            const ggml_fp16_t * f = (const ggml_fp16_t * )((const char *) ctx->kv_cross.k->data + offs);
+
+            for (int j = 0; j < n_ctx*n_state; ++j) {
+                embd[i*n_ctx*n_state + j] = ggml_fp16_to_fp32(f[j]);
+            }
+        }
+
+        const int n_features = 64;
+
+        ggml_svd_reduce_dims(n_ctx*n_state, n_segments, embd.data(), n_features);
+
+        std::vector<std::vector<float>> features(n_segments);
+
+        for (int i = 0; i < n_segments; ++i) {
+            features[i].resize(n_features);
+            for (int j = 0; j < n_features; ++j) {
+                features[i][j] = embd[i*n_features + j];
+            }
+        }
+
+        // fuzzy c-means clustering
+        const int n_clusters = 2;
+
+        std::vector<std::vector<float>> centroids(n_clusters, std::vector<float>(n_features, 0.0));
+        std::vector<std::vector<float>> membership(n_segments, std::vector<float>(n_clusters, 0.0));
+
+        // initialize the centroids
+        for (int i = 0; i < n_clusters; ++i) {
+            for (int j = 0; j < n_features; ++j) {
+                centroids[i][j] = features[i][j];
+            }
+        }
+
+        // initialize the membership
+        for (int i = 0; i < n_segments; ++i) {
+            //membership[i][i % n_clusters] = 1.0;
+            for (int j = 0; j < n_clusters; ++j) {
+                membership[i][j] = rand() / (float) RAND_MAX;
+            }
+        }
+
+        const int niter = 10000;
+
+        // iterate
+        for (int i = 0; i < niter; ++i) {
+            // update the centroids
+            for (int j = 0; j < n_clusters; ++j) {
+                for (int k = 0; k < n_features; ++k) {
+                    centroids[j][k] = 0.0;
+                }
+            }
+
+            for (int j = 0; j < n_segments; ++j) {
+                for (int k = 0; k < n_clusters; ++k) {
+                    for (int l = 0; l < n_features; ++l) {
+                        centroids[k][l] += membership[j][k]*features[j][l];
+                    }
+                }
+            }
+
+            for (int j = 0; j < n_clusters; ++j) {
+                float sum = 0.0;
+                for (int k = 0; k < n_segments; ++k) {
+                    sum += membership[k][j];
+                }
+
+                for (int k = 0; k < n_features; ++k) {
+                    centroids[j][k] /= sum;
+                }
+            }
+
+            // update the membership
+            for (int j = 0; j < n_segments; ++j) {
+                for (int k = 0; k < n_clusters; ++k) {
+                    float sum = 0.0;
+                    for (int l = 0; l < n_clusters; ++l) {
+                        //sum += std::pow(whisper_distance(features[j], centroids[k])/whisper_distance(features[j], centroids[l]), 2.0/(2.0 - 1.0));
+
+                        // use the euclidean distance
+                        double d0 = 0.0;
+                        for (int m = 0; m < n_features; ++m) {
+                            d0 += std::pow(features[j][m] - centroids[k][m], 2.0);
+                        }
+                        d0 = std::sqrt(d0);
+
+                        double d1 = 0.0;
+                        for (int m = 0; m < n_features; ++m) {
+                            d1 += std::pow(features[j][m] - centroids[l][m], 2.0);
+                        }
+                        d1 = std::sqrt(d1);
+
+                        if (d1 == 0.0) {
+                            sum += 1.0;
+                        } else {
+                            sum += std::pow(d0/d1, 2.0/(1.10 - 1.0));
+                        }
+                    }
+
+                    membership[j][k] = 1.0/sum;
+                }
+            }
+
+            // print the membership
+            if (i == niter - 1) {
+                for (int i = 0; i < n_segments; ++i) {
+                    printf("%s: membership %3d: ", __func__, i);
+                    for (int j = 0; j < n_clusters; ++j) {
+                        printf("%f ", membership[i][j]);
+                    }
+                    printf(" '%s'\n", ctx->result_all[i].text.c_str());
+                    //printf("%s: features      : ", __func__);
+                    //for (int j = 0; j < n_features; ++j) {
+                    //    printf("%8.3f ", features[i][j]);
+                    //}
+                    //printf(" '%s'\n", ctx->result_all[i].text.c_str());
+                }
+                printf("----------------\n");
+            }
+        }
+
+        // print the centroids
+        for (int i = 0; i < n_clusters; ++i) {
+            printf("%s: centroid %d: ", __func__, i);
+            for (int j = 0; j < n_features; ++j) {
+                printf("%f ", centroids[i][j]);
+            }
+            printf("\n");
+        }
+    }
 
     // restore the mel length
     ctx->mel.n_len = mel_len_save;
 }
-