使用TSNE可视化KMeans聚类异常排查及替代可视化方案咨询
关于KMeans聚类TSNE可视化异常的问题解答
先看你的代码里的明显笔误——你定义的聚类模型是kmeanModel,但可视化散点图时用了km.labels_,km这个变量根本没定义啊!这几乎肯定是你看到4条分组线的原因:要么是你之前测试过n_clusters=4的模型,不小心把变量名搞混了;要么就是单纯的拼写错误,导致调用了错误的标签数组。
是可视化问题还是KMeans算法的问题?
大概率是可视化代码的错误,而非KMeans本身的问题。只要你明确指定了n_clusters=3,KMeans只会输出0、1、2三个类别标签(除非你的数据极端特殊,但从你给出的样本数据来看完全不可能)。你可以在拟合模型后加一行代码确认:
print(np.unique(kmeanModel.labels_))
如果输出是[0 1 2],那百分百是可视化时的变量名错误;如果输出有4个值,那说明你可能在某个地方修改了模型参数(比如不小心把n_clusters设成了4)。
另外,你的代码里还有个多余的步骤:你先对整个数据集做了KMeans拟合,然后又做了train_test_split,但后面的降维和可视化根本没用到划分后的数据集,这部分代码可以删掉,避免混淆。
其他常用的聚类可视化方法
除了TSNE,还有这些方法可以直观展示聚类效果:
- PCA散点图:和TSNE类似,但属于线性降维,计算速度更快,适合快速初步观察聚类的分离情况。
- 散点图矩阵(Pair Plot):如果你的特征数量不多(比如≤5个),用Seaborn的
sns.pairplot,把聚类标签作为hue参数,可以清晰看到不同类别在每个特征维度上的分布差异。 - 轮廓系数图:用
sklearn.metrics.silhouette_samples计算每个样本的轮廓系数,然后画出箱线图或条形图,能直观反映聚类的紧凑度和类别间的分离度。 - 平行坐标图:当特征数量较多时,平行坐标可以展示不同类别在各个特征上的取值趋势,帮助你理解类别之间的差异来源。
- 核密度估计图:针对单个特征,画出不同类别的核密度曲线,对比类别的分布形态。
修正后的代码示例
我帮你修正了变量名错误,还加了random_state保证结果可复现,同时去掉了多余的train_test_split:
import pandas as pd import numpy as np import ast from sklearn.cluster import KMeans import matplotlib.pyplot as plt from sklearn.manifold import TSNE from sklearn.decomposition import TruncatedSVD colNames = ['unixTime', 'sampleAmount','Time','samplingRate', 'Data'] data = pd.read_csv("project_fan.csv", sep = ';', error_bad_lines = False, names = colNames) # 将Data列转为列表并取平均值 data['Data'] = data.Data.transform(ast.literal_eval) data['Data'] = data.Data.apply(np.mean) # 拟合KMeans模型,指定random_state保证结果一致 kmeanModel = KMeans(n_clusters = 3, random_state=42) kmeanModel.fit(data) y_labels = kmeanModel.labels_ # 确认聚类标签数量 print("聚类标签唯一值:", np.unique(y_labels)) # 降维:先TruncatedSVD到3维,再TSNE到2维 tfs_reduced = TruncatedSVD(n_components=3, random_state=0).fit_transform(data) tfs_embedded = TSNE(n_components=2, perplexity=40, verbose=2, random_state=42).fit_transform(tfs_reduced) # 可视化 fig = plt.figure(figsize = (10, 10)) plt.scatter(tfs_embedded[:, 0], tfs_embedded[:, 1], marker = "x", c = y_labels) plt.title("TSNE Visualization of KMeans Clustering (3 Clusters)") plt.show()
内容的提问来源于stack exchange,提问作者x89




