You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何用Python通过DBSCAN获取各聚类的中心与半径?

嘿,我来帮你搞定获取聚类中心和半径的问题!针对你用的MeanShift和DBSCAN两种算法,它们的特性不一样,所以获取中心和半径的方式也有区别,我给你一步步说明,还会把代码补全给你参考:


MeanShift:直接获取中心,计算半径

MeanShift算法本身会输出聚类中心,你可以直接通过meanshift.cluster_centers_拿到。至于半径,我们可以计算每个聚类里所有点到对应中心的最大距离,把这个值当作该聚类的半径(能覆盖所有聚类内点的最小圆半径)。

DBSCAN:手动计算中心与半径

DBSCAN是基于密度的聚类算法,没有自带的聚类中心属性。我们需要手动推导:

  • 中心:取每个聚类内所有点的坐标均值(也可以用中位数,不过均值更常用)
  • 半径:同样取聚类内点到该中心的最大距离
    注意:DBSCAN会把噪声点标记为-1,计算的时候一定要把这些点排除掉!

完整代码示例(含可视化)

我把你的代码补全,加上了获取中心、半径的逻辑,还新增了绘制聚类范围的功能:

import numpy as np
import matplotlib.pyplot as plt
from sklearn import cluster

def cluster_plots(data, labels1, labels2, centers1=None, radii1=None, centers2=None, radii2=None, 
                  colours1='gray', colours2='gray', title1='DBSCAN Clustering', title2='MeanShift Clustering'):
    fig, (ax1, ax2) = plt.subplots(1, 2)
    fig.set_size_inches(12, 5)
    
    # 绘制DBSCAN结果+中心半径
    ax1.set_title(title1, fontsize=14)
    ax1.scatter(data[:, 0], data[:, 1], c=labels1, cmap=colours1, alpha=0.6)
    if centers1 is not None and radii1 is not None:
        for center, radius in zip(centers1, radii1):
            # 绘制聚类范围圆
            circle = plt.Circle(center, radius, color='red', fill=False, linestyle='--', linewidth=2)
            ax1.add_patch(circle)
            # 标记聚类中心
            ax1.scatter(center[0], center[1], c='red', marker='x', s=120)
    
    # 绘制MeanShift结果+中心半径
    ax2.set_title(title2, fontsize=14)
    ax2.scatter(data[:, 0], data[:, 1], c=labels2, cmap=colours2, alpha=0.6)
    if centers2 is not None and radii2 is not None:
        for center, radius in zip(centers2, radii2):
            circle = plt.Circle(center, radius, color='blue', fill=False, linestyle='--', linewidth=2)
            ax2.add_patch(circle)
            ax2.scatter(center[0], center[1], c='blue', marker='x', s=120)
    
    plt.tight_layout()
    plt.show()

# 生成随机测试数据
np.random.seed(42)
X = np.vstack([np.random.normal(0, 1, (100, 2)),
               np.random.normal(5, 1, (100, 2)),
               np.random.normal(10, 1.5, (100, 2))])

# ---------------------- MeanShift 聚类与参数提取 ----------------------
ms_model = cluster.MeanShift()
ms_labels = ms_model.fit_predict(X)
ms_centers = ms_model.cluster_centers_  # 直接获取聚类中心

# 计算MeanShift每个聚类的半径
ms_radii = []
for idx, center in enumerate(ms_centers):
    # 筛选当前聚类的所有点
    cluster_points = X[ms_labels == idx]
    # 计算点到中心的欧氏距离,取最大值作为半径
    max_distance = np.max(np.linalg.norm(cluster_points - center, axis=1))
    ms_radii.append(max_distance)

# ---------------------- DBSCAN 聚类与参数提取 ----------------------
db_model = cluster.DBSCAN(eps=1.2, min_samples=5)
db_labels = db_model.fit_predict(X)

# 过滤噪声点(标签为-1的点),获取有效聚类标签
valid_db_labels = [label for label in np.unique(db_labels) if label != -1]
db_centers = []
db_radii = []

for label in valid_db_labels:
    cluster_points = X[db_labels == label]
    # 计算聚类中心(坐标均值)
    cluster_center = np.mean(cluster_points, axis=0)
    db_centers.append(cluster_center)
    # 计算聚类半径
    max_distance = np.max(np.linalg.norm(cluster_points - cluster_center, axis=1))
    db_radii.append(max_distance)

# 可视化结果(带中心和半径标记)
cluster_plots(X, db_labels, ms_labels, 
              centers1=np.array(db_centers), radii1=db_radii,
              centers2=np.array(ms_centers), radii2=ms_radii,
              colours1='viridis', colours2='plasma')

# 打印聚类参数
print("=== MeanShift 聚类信息 ===")
for i, (center, radius) in enumerate(zip(ms_centers, ms_radii)):
    print(f"聚类{i+1}: 中心={center.round(2)}, 半径={radius.round(2)}")

print("\n=== DBSCAN 聚类信息 ===")
for i, (center, radius) in enumerate(zip(db_centers, db_radii)):
    print(f"聚类{i+1}: 中心={center.round(2)}, 半径={radius.round(2)}")

关键逻辑说明

  1. MeanShift:利用算法自带的cluster_centers_属性直接拿中心,半径通过聚类内点到中心的最大距离计算,确保能覆盖所有点。
  2. DBSCAN:先过滤噪声点,再对每个有效聚类计算均值中心,半径同样取最大距离。
  3. 可视化部分用plt.Circle绘制了聚类的范围圆,方便你直观验证结果。

内容的提问来源于stack exchange,提问作者Luo Zin-Han

火山引擎 最新活动