You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

Matplotlib绘制带网格线的热力图时网格线缺失问题

Matplotlib绘制带网格线的热力图时网格线缺失问题

我尝试绘制带网格线的热力图,以下是我参考相关内容改编的代码:

# Plot a heatmap with gridlines

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from functional import seq

arr = np.random.randn(3, 20)

plt.tight_layout()
ax = plt.subplot(111)
ax.imshow(arr, cmap='viridis')

xr = ax.get_xlim()
yr = ax.get_ylim()
ax.set_xticks(np.arange(max(xr))-0.5, minor=True)
ax.set_yticks(np.arange(max(yr))-0.5, minor=True)
ax.grid(which='minor', snap=False, color='k', linestyle='-', linewidth=1)
ax.tick_params(which='major', bottom=False, left=False)
ax.tick_params(which='minor', bottom=False, left=False)
ax.set_xticklabels([])
ax.set_yticklabels([])
for spine in ax.spines.values():
    spine.set_visible(False)

plt.show()

我得到了这样的热力图(裁剪后):

异常的热力图

可以看到第3列、第8列以及倒数第3列之后都没有垂直网格线。因为我是通过次要刻度来绘制网格线的,所以我尝试打印x轴刻度,并且不隐藏标签重新绘图:

print('xticks:', np.arange(max(xr))-0.5)

输出结果:

xticks: [-0.5  0.5  1.5  2.5  3.5  4.5  5.5  6.5  7.5  8.5  9.5 10.5 11.5 12.5
 13.5 14.5 15.5 16.5 17.5 18.5]

对应的绘图:

显示标签的异常热力图

从输出结果看需要的刻度都存在,但网格线还是缺失,这是怎么回事呢?


问题原因分析

你遇到的问题根源在于依赖ax.get_xlim()的最大值生成次要刻度,这种方式不够可靠。对于20列的数组,imshow默认的x轴范围是(-0.5, 19.5)max(xr)得到19.5,而np.arange(19.5)会生成从0到19的整数序列,减0.5后得到的刻度是-0.5, 0.5, ..., 18.5,这就漏掉了最右侧的19.5刻度,导致最后一个单元格的右边界没有网格线。而中间部分的网格线缺失,则是浮点数精度或np.arange的行为导致的刻度序列覆盖不完整。

解决方案:基于数组形状生成刻度

直接根据数组的列数和行数来生成次要刻度,这样能精准覆盖所有单元格的边界,避免依赖xlim/ylim的不确定性:

修改代码中设置次要刻度的部分,替换为以下两行:

ax.set_xticks(np.arange(-0.5, arr.shape[1], 1), minor=True)
ax.set_yticks(np.arange(-0.5, arr.shape[0], 1), minor=True)

修正后的完整代码

# Plot a heatmap with gridlines

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from functional import seq

arr = np.random.randn(3, 20)

plt.tight_layout()
ax = plt.subplot(111)
ax.imshow(arr, cmap='viridis')

# 基于数组形状生成次要刻度,覆盖所有单元格边界
ax.set_xticks(np.arange(-0.5, arr.shape[1], 1), minor=True)
ax.set_yticks(np.arange(-0.5, arr.shape[0], 1), minor=True)

ax.grid(which='minor', snap=False, color='k', linestyle='-', linewidth=1)
ax.tick_params(which='major', bottom=False, left=False)
ax.tick_params(which='minor', bottom=False, left=False)
ax.set_xticklabels([])
ax.set_yticklabels([])
for spine in ax.spines.values():
    spine.set_visible(False)

plt.show()

修改后,所有垂直和水平网格线都会正确绘制出来,完美包裹每个单元格。

更简洁的替代方案

如果你想更高效地实现需求,也可以使用seaborn.heatmap,它自带网格线支持,只需要设置linewidths参数即可,代码更简洁且不会出现网格线缺失问题:

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

arr = np.random.randn(3, 20)
sns.heatmap(arr, cmap='viridis', linewidths=1, linecolor='black', 
            xticklabels=False, yticklabels=False, cbar=False)
plt.tight_layout()
plt.show()

备注:内容来源于stack exchange,提问作者Yuxi L

火山引擎 最新活动