You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

Matplotlib:如何为每个散点图添加对应y值的图例?

给鸢尾花散点子图添加类别图例的解决方案

嘿,我来帮你搞定这个问题!你现在的代码用c=y给散点着色,但matplotlib不会自动为这种数值型颜色映射生成对应类别的图例——因为它只是把y当成颜色值,而不是分类标签。我们可以通过按类别单独绘制散点并指定标签,或者手动创建图例元素的方式来实现需求,下面给你两种可行的方案:

方案一:纯Matplotlib手动按类别绘制(最直观)

我们可以循环每个鸢尾花类别,针对每个类别单独绘制散点,并指定对应的图例标签,最后在每个子图上调用legend()生成图例。修改后的完整代码如下:

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
from sklearn import datasets

# 加载数据并整理成DataFrame
data = datasets.load_iris(return_X_y=False)
X = data.data
y = data.target
names = data.feature_names
target_names = data.target_names
columns = names + ['target']
df = pd.DataFrame(np.hstack([X, y.reshape(-1, 1)]), columns=columns)
df['target_names'] = df['target'].map({0: 'setosa', 1: 'versicolor', 2: 'virginica'})
indexes = df.index.tolist()

# 创建子图
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
# 遍历每个特征和对应的子图位置
features = ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
subplot_positions = [(0,0), (0,1), (1,0), (1,1)]

for feature, (row, col) in zip(features, subplot_positions):
    ax = axes[row, col]
    # 按类别循环绘制散点,每个类别指定标签
    for target_val, name in zip([0,1,2], target_names):
        # 筛选当前类别的数据
        mask = df['target'] == target_val
        ax.scatter(indexes[mask], df.loc[mask, feature], label=name)
    ax.set_title(feature)
    ax.legend(title='Iris Species')  # 添加图例并设置标题

plt.tight_layout()  # 自动调整子图间距
plt.show()

代码说明:

  • map()方法简化了target到target_names的映射(比原来的三次loc更简洁)
  • 循环每个特征和对应的子图,然后针对每个鸢尾花类别单独绘制散点,通过label参数指定图例显示的名称
  • 每个子图调用legend()生成图例,还可以通过title参数给图例加标题,让图表更清晰

方案二:用Seaborn简化绘制(更高效)

如果你不介意用Seaborn来辅助绘图,代码会更简洁,Seaborn会自动处理图例:

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
from sklearn import datasets

# 加载数据并整理成DataFrame
data = datasets.load_iris(return_X_y=False)
X = data.data
y = data.target
names = data.feature_names
target_names = data.target_names
columns = names + ['target']
df = pd.DataFrame(np.hstack([X, y.reshape(-1, 1)]), columns=columns)
df['target_names'] = df['target'].map({0: 'setosa', 1: 'versicolor', 2: 'virginica'})
indexes = df.index.tolist()

# 创建子图
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
features = ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
subplot_positions = [(0,0), (0,1), (1,0), (1,1)]

for feature, (row, col) in zip(features, subplot_positions):
    ax = axes[row, col]
    # 用seaborn的scatterplot自动生成图例
    sns.scatterplot(x=indexes, y=feature, hue='target_names', data=df, ax=ax)
    ax.set_title(feature)

plt.tight_layout()
plt.show()

代码说明:

  • Seaborn的scatterplot通过hue参数指定分类列,会自动生成对应的图例,不需要手动循环绘制
  • 代码量更少,图表样式也更美观,适合快速生成带分类图例的可视化

两种方案都能实现你想要的效果,你可以根据自己的习惯选择~

内容的提问来源于stack exchange,提问作者harp1814

火山引擎 最新活动