You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

PyTorch中GRU层h_n输出的维度排布及正向隐藏状态提取方法咨询

PyTorch中GRU层h_n输出的维度排布及正向隐藏状态提取方法咨询

嗨,看了你这段刚接触PyTorch一个月写的GRU代码,我来帮你理清楚双向GRU中hn的维度排布问题~

首先,你观察到hn的形状是torch.Size([4, 64, 32]),这个确实是2层GRU × 2个方向的结果。关于它的排布顺序,PyTorch的规则是:按层划分,每一层的正向隐藏状态在前,反向隐藏状态在后,并不是交替排列的。具体到你的两层双向GRU,hn的每个索引对应的含义是:

  • 索引0:第1层GRU的正向最后时刻隐藏状态
  • 索引1:第1层GRU的反向最后时刻隐藏状态
  • 索引2:第2层GRU的正向最后时刻隐藏状态
  • 索引3:第2层GRU的反向最后时刻隐藏状态

再看你提取正向隐藏状态的代码:

forward_hidden = hn[[x for x in range(0, gru_layers_count * 2, 2)], :, :]

这个写法是完全正确的,它会选取索引0和2的元素,正好对应两层的正向隐藏状态。不过可以给你一个更简洁的写法,效果完全一致:

forward_hidden = hn[::2, :, :]

这里的::2表示从索引0开始,每隔一个元素取一次,刚好拿到所有层的正向隐藏状态。如果要提取反向的,就用hn[1::2, :, :]即可。

备注:内容来源于stack exchange,提问作者ashwani kumar dwivedi

火山引擎 最新活动