Matplotlib:如何为每个散点图添加对应y值的图例?
给鸢尾花散点子图添加类别图例的解决方案
嘿,我来帮你搞定这个问题!你现在的代码用c=y给散点着色,但matplotlib不会自动为这种数值型颜色映射生成对应类别的图例——因为它只是把y当成颜色值,而不是分类标签。我们可以通过按类别单独绘制散点并指定标签,或者手动创建图例元素的方式来实现需求,下面给你两种可行的方案:
方案一:纯Matplotlib手动按类别绘制(最直观)
我们可以循环每个鸢尾花类别,针对每个类别单独绘制散点,并指定对应的图例标签,最后在每个子图上调用legend()生成图例。修改后的完整代码如下:
import matplotlib.pyplot as plt import numpy as np import seaborn as sns import pandas as pd from sklearn import datasets # 加载数据并整理成DataFrame data = datasets.load_iris(return_X_y=False) X = data.data y = data.target names = data.feature_names target_names = data.target_names columns = names + ['target'] df = pd.DataFrame(np.hstack([X, y.reshape(-1, 1)]), columns=columns) df['target_names'] = df['target'].map({0: 'setosa', 1: 'versicolor', 2: 'virginica'}) indexes = df.index.tolist() # 创建子图 fig, axes = plt.subplots(2, 2, figsize=(12, 8)) # 遍历每个特征和对应的子图位置 features = ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'] subplot_positions = [(0,0), (0,1), (1,0), (1,1)] for feature, (row, col) in zip(features, subplot_positions): ax = axes[row, col] # 按类别循环绘制散点,每个类别指定标签 for target_val, name in zip([0,1,2], target_names): # 筛选当前类别的数据 mask = df['target'] == target_val ax.scatter(indexes[mask], df.loc[mask, feature], label=name) ax.set_title(feature) ax.legend(title='Iris Species') # 添加图例并设置标题 plt.tight_layout() # 自动调整子图间距 plt.show()
代码说明:
- 用
map()方法简化了target到target_names的映射(比原来的三次loc更简洁) - 循环每个特征和对应的子图,然后针对每个鸢尾花类别单独绘制散点,通过
label参数指定图例显示的名称 - 每个子图调用
legend()生成图例,还可以通过title参数给图例加标题,让图表更清晰
方案二:用Seaborn简化绘制(更高效)
如果你不介意用Seaborn来辅助绘图,代码会更简洁,Seaborn会自动处理图例:
import matplotlib.pyplot as plt import numpy as np import seaborn as sns import pandas as pd from sklearn import datasets # 加载数据并整理成DataFrame data = datasets.load_iris(return_X_y=False) X = data.data y = data.target names = data.feature_names target_names = data.target_names columns = names + ['target'] df = pd.DataFrame(np.hstack([X, y.reshape(-1, 1)]), columns=columns) df['target_names'] = df['target'].map({0: 'setosa', 1: 'versicolor', 2: 'virginica'}) indexes = df.index.tolist() # 创建子图 fig, axes = plt.subplots(2, 2, figsize=(12, 8)) features = ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'] subplot_positions = [(0,0), (0,1), (1,0), (1,1)] for feature, (row, col) in zip(features, subplot_positions): ax = axes[row, col] # 用seaborn的scatterplot自动生成图例 sns.scatterplot(x=indexes, y=feature, hue='target_names', data=df, ax=ax) ax.set_title(feature) plt.tight_layout() plt.show()
代码说明:
- Seaborn的
scatterplot通过hue参数指定分类列,会自动生成对应的图例,不需要手动循环绘制 - 代码量更少,图表样式也更美观,适合快速生成带分类图例的可视化
两种方案都能实现你想要的效果,你可以根据自己的习惯选择~
内容的提问来源于stack exchange,提问作者harp1814




