You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

K-Means聚类运行耗时过长问题咨询

K-Means聚类运行耗时过长问题咨询

Hey there! First off, 65+ minutes for this K-Means run on your 400k-row dataset does feel slower than expected, especially when you're just trying to nail down the optimal k. Let's break down what's probably dragging things out and how you can speed this up quickly.

首先,找到耗时的核心原因

你的代码里最吃时间的绝对是**silhouette_score的计算**——这个指标要对比每个样本和同簇、异簇样本的平均距离,时间复杂度是O(n²),40万样本下这个计算量会爆炸式增长,大概率是它让你的程序卡了这么久。

另外,虽然你用了PCA降维,但如果降维后的维度还比较高(比如20+维),也会额外增加KMeans和Silhouette计算的负担。


给你几个具体的优化方案(附代码修改)

我直接针对你的代码给可落地的修改建议,改完后速度应该能提升一个数量级:

1. 优先优化Silhouette Score的计算

你完全不需要给每个k值都计算全量样本的Silhouette Score,有两个高效思路:

  • 先缩窄k的范围再计算:先通过肘部法拿到大致趋势,再挑2-3个看起来最优的k(比如肘部附近的数值)来算Silhouette,而不是每个k都跑一遍
  • 用采样计算替代全量计算:取10%-20%的样本计算Silhouette,结果足够用来判断最优k,速度却能快10倍以上

2. 换成MiniBatchKMeans替代普通KMeans

针对大数据集,sklearn提供了MiniBatchKMeans,它用小批次样本迭代训练,速度比标准KMeans快很多,结果也和标准KMeans非常接近,完美适配你现在找最优k的场景。

修改后的代码示例

from sklearn.cluster import MiniBatchKMeans
from sklearn.utils import resample
import matplotlib.pyplot as plt
from sklearn.metrics import silhouette_score

sse = []
silhouette_scores = []
k_values = [3, 5, 7, 9]
# 采样比例,取10%的样本计算Silhouette,大幅提速
sample_ratio = 0.1

for k in k_values:
    # 替换为MiniBatchKMeans,适配大数据集
    kmeans = MiniBatchKMeans(n_clusters=k, random_state=42, n_init=1, init='k-means++', batch_size=2048)
    kmeans.fit(x_pca)
    sse.append(kmeans.inertia_)
    
    # 对样本和标签同步采样,避免数据不匹配
    if sample_ratio < 1.0:
        x_sample, labels_sample = resample(x_pca, kmeans.labels_, n_samples=int(len(x_pca)*sample_ratio), random_state=42)
        sil_score = silhouette_score(x_sample, labels_sample)
    else:
        sil_score = silhouette_score(x_pca, kmeans.labels_)
    silhouette_scores.append(sil_score)

# 肘部法画图(代码不变)
plt.figure(figsize=(10, 6))
plt.plot(k_values, sse, marker='o')
plt.title('Elbow Method for Optimal k')
plt.xlabel('Number of clusters (k)')
plt.ylabel('Sum of Squared Errors (SSE)')
plt.show()

# Silhouette分数画图(代码不变)
plt.figure(figsize=(10, 6))
plt.plot(k_values, silhouette_scores, marker='o')
plt.title('Silhouette Score for Optimal k')
plt.xlabel('Number of clusters (k)')
plt.ylabel('Silhouette Score')
plt.show()

一些额外的小建议

  1. 先做小范围测试:比如先单独跑k=3的循环,看看KMeans拟合和Silhouette计算各花了多久,精准定位耗时点
  2. 检查PCA降维维度:如果降维后还剩15+维,可以尝试再降一点(比如到10维以内),用explained_variance_ratio_看看累计方差占比,确保不损失太多关键信息
  3. 给循环加进度条:装个tqdm库,把循环改成for k in tqdm(k_values):,这样能实时看到每个k的运行进度,不会以为程序卡崩了
  4. 关于n_init=1:这个设置没问题,为了快速找k,用1次初始化足够;如果是最终模型训练,再调大n_init保证稳定性就行

按照这些方法改完,应该能把运行时间从几十分钟压缩到几分钟,甚至更短!

备注:内容来源于stack exchange,提问作者Joud

火山引擎 最新活动