如何用Python通过DBSCAN获取各聚类的中心与半径?
嘿,我来帮你搞定获取聚类中心和半径的问题!针对你用的MeanShift和DBSCAN两种算法,它们的特性不一样,所以获取中心和半径的方式也有区别,我给你一步步说明,还会把代码补全给你参考:
MeanShift:直接获取中心,计算半径
MeanShift算法本身会输出聚类中心,你可以直接通过meanshift.cluster_centers_拿到。至于半径,我们可以计算每个聚类里所有点到对应中心的最大距离,把这个值当作该聚类的半径(能覆盖所有聚类内点的最小圆半径)。
DBSCAN:手动计算中心与半径
DBSCAN是基于密度的聚类算法,没有自带的聚类中心属性。我们需要手动推导:
- 中心:取每个聚类内所有点的坐标均值(也可以用中位数,不过均值更常用)
- 半径:同样取聚类内点到该中心的最大距离
注意:DBSCAN会把噪声点标记为-1,计算的时候一定要把这些点排除掉!
完整代码示例(含可视化)
我把你的代码补全,加上了获取中心、半径的逻辑,还新增了绘制聚类范围的功能:
import numpy as np import matplotlib.pyplot as plt from sklearn import cluster def cluster_plots(data, labels1, labels2, centers1=None, radii1=None, centers2=None, radii2=None, colours1='gray', colours2='gray', title1='DBSCAN Clustering', title2='MeanShift Clustering'): fig, (ax1, ax2) = plt.subplots(1, 2) fig.set_size_inches(12, 5) # 绘制DBSCAN结果+中心半径 ax1.set_title(title1, fontsize=14) ax1.scatter(data[:, 0], data[:, 1], c=labels1, cmap=colours1, alpha=0.6) if centers1 is not None and radii1 is not None: for center, radius in zip(centers1, radii1): # 绘制聚类范围圆 circle = plt.Circle(center, radius, color='red', fill=False, linestyle='--', linewidth=2) ax1.add_patch(circle) # 标记聚类中心 ax1.scatter(center[0], center[1], c='red', marker='x', s=120) # 绘制MeanShift结果+中心半径 ax2.set_title(title2, fontsize=14) ax2.scatter(data[:, 0], data[:, 1], c=labels2, cmap=colours2, alpha=0.6) if centers2 is not None and radii2 is not None: for center, radius in zip(centers2, radii2): circle = plt.Circle(center, radius, color='blue', fill=False, linestyle='--', linewidth=2) ax2.add_patch(circle) ax2.scatter(center[0], center[1], c='blue', marker='x', s=120) plt.tight_layout() plt.show() # 生成随机测试数据 np.random.seed(42) X = np.vstack([np.random.normal(0, 1, (100, 2)), np.random.normal(5, 1, (100, 2)), np.random.normal(10, 1.5, (100, 2))]) # ---------------------- MeanShift 聚类与参数提取 ---------------------- ms_model = cluster.MeanShift() ms_labels = ms_model.fit_predict(X) ms_centers = ms_model.cluster_centers_ # 直接获取聚类中心 # 计算MeanShift每个聚类的半径 ms_radii = [] for idx, center in enumerate(ms_centers): # 筛选当前聚类的所有点 cluster_points = X[ms_labels == idx] # 计算点到中心的欧氏距离,取最大值作为半径 max_distance = np.max(np.linalg.norm(cluster_points - center, axis=1)) ms_radii.append(max_distance) # ---------------------- DBSCAN 聚类与参数提取 ---------------------- db_model = cluster.DBSCAN(eps=1.2, min_samples=5) db_labels = db_model.fit_predict(X) # 过滤噪声点(标签为-1的点),获取有效聚类标签 valid_db_labels = [label for label in np.unique(db_labels) if label != -1] db_centers = [] db_radii = [] for label in valid_db_labels: cluster_points = X[db_labels == label] # 计算聚类中心(坐标均值) cluster_center = np.mean(cluster_points, axis=0) db_centers.append(cluster_center) # 计算聚类半径 max_distance = np.max(np.linalg.norm(cluster_points - cluster_center, axis=1)) db_radii.append(max_distance) # 可视化结果(带中心和半径标记) cluster_plots(X, db_labels, ms_labels, centers1=np.array(db_centers), radii1=db_radii, centers2=np.array(ms_centers), radii2=ms_radii, colours1='viridis', colours2='plasma') # 打印聚类参数 print("=== MeanShift 聚类信息 ===") for i, (center, radius) in enumerate(zip(ms_centers, ms_radii)): print(f"聚类{i+1}: 中心={center.round(2)}, 半径={radius.round(2)}") print("\n=== DBSCAN 聚类信息 ===") for i, (center, radius) in enumerate(zip(db_centers, db_radii)): print(f"聚类{i+1}: 中心={center.round(2)}, 半径={radius.round(2)}")
关键逻辑说明
- MeanShift:利用算法自带的
cluster_centers_属性直接拿中心,半径通过聚类内点到中心的最大距离计算,确保能覆盖所有点。 - DBSCAN:先过滤噪声点,再对每个有效聚类计算均值中心,半径同样取最大距离。
- 可视化部分用
plt.Circle绘制了聚类的范围圆,方便你直观验证结果。
内容的提问来源于stack exchange,提问作者Luo Zin-Han




