Matplotlib绘制带网格线的热力图时网格线缺失问题
Matplotlib绘制带网格线的热力图时网格线缺失问题
我尝试绘制带网格线的热力图,以下是我参考相关内容改编的代码:
# Plot a heatmap with gridlines import numpy as np import pandas as pd import matplotlib.pyplot as plt import seaborn as sns from functional import seq arr = np.random.randn(3, 20) plt.tight_layout() ax = plt.subplot(111) ax.imshow(arr, cmap='viridis') xr = ax.get_xlim() yr = ax.get_ylim() ax.set_xticks(np.arange(max(xr))-0.5, minor=True) ax.set_yticks(np.arange(max(yr))-0.5, minor=True) ax.grid(which='minor', snap=False, color='k', linestyle='-', linewidth=1) ax.tick_params(which='major', bottom=False, left=False) ax.tick_params(which='minor', bottom=False, left=False) ax.set_xticklabels([]) ax.set_yticklabels([]) for spine in ax.spines.values(): spine.set_visible(False) plt.show()
我得到了这样的热力图(裁剪后):

可以看到第3列、第8列以及倒数第3列之后都没有垂直网格线。因为我是通过次要刻度来绘制网格线的,所以我尝试打印x轴刻度,并且不隐藏标签重新绘图:
print('xticks:', np.arange(max(xr))-0.5)
输出结果:
xticks: [-0.5 0.5 1.5 2.5 3.5 4.5 5.5 6.5 7.5 8.5 9.5 10.5 11.5 12.5 13.5 14.5 15.5 16.5 17.5 18.5]
对应的绘图:

从输出结果看需要的刻度都存在,但网格线还是缺失,这是怎么回事呢?
问题原因分析
你遇到的问题根源在于依赖ax.get_xlim()的最大值生成次要刻度,这种方式不够可靠。对于20列的数组,imshow默认的x轴范围是(-0.5, 19.5),max(xr)得到19.5,而np.arange(19.5)会生成从0到19的整数序列,减0.5后得到的刻度是-0.5, 0.5, ..., 18.5,这就漏掉了最右侧的19.5刻度,导致最后一个单元格的右边界没有网格线。而中间部分的网格线缺失,则是浮点数精度或np.arange的行为导致的刻度序列覆盖不完整。
解决方案:基于数组形状生成刻度
直接根据数组的列数和行数来生成次要刻度,这样能精准覆盖所有单元格的边界,避免依赖xlim/ylim的不确定性:
修改代码中设置次要刻度的部分,替换为以下两行:
ax.set_xticks(np.arange(-0.5, arr.shape[1], 1), minor=True) ax.set_yticks(np.arange(-0.5, arr.shape[0], 1), minor=True)
修正后的完整代码
# Plot a heatmap with gridlines import numpy as np import pandas as pd import matplotlib.pyplot as plt import seaborn as sns from functional import seq arr = np.random.randn(3, 20) plt.tight_layout() ax = plt.subplot(111) ax.imshow(arr, cmap='viridis') # 基于数组形状生成次要刻度,覆盖所有单元格边界 ax.set_xticks(np.arange(-0.5, arr.shape[1], 1), minor=True) ax.set_yticks(np.arange(-0.5, arr.shape[0], 1), minor=True) ax.grid(which='minor', snap=False, color='k', linestyle='-', linewidth=1) ax.tick_params(which='major', bottom=False, left=False) ax.tick_params(which='minor', bottom=False, left=False) ax.set_xticklabels([]) ax.set_yticklabels([]) for spine in ax.spines.values(): spine.set_visible(False) plt.show()
修改后,所有垂直和水平网格线都会正确绘制出来,完美包裹每个单元格。
更简洁的替代方案
如果你想更高效地实现需求,也可以使用seaborn.heatmap,它自带网格线支持,只需要设置linewidths参数即可,代码更简洁且不会出现网格线缺失问题:
import numpy as np import seaborn as sns import matplotlib.pyplot as plt arr = np.random.randn(3, 20) sns.heatmap(arr, cmap='viridis', linewidths=1, linecolor='black', xticklabels=False, yticklabels=False, cbar=False) plt.tight_layout() plt.show()
备注:内容来源于stack exchange,提问作者Yuxi L




