You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何高效将3D NumPy布尔数组中连续True替换为对应连续组长度

高效替换3D NumPy布尔数组中连续True为其所在段的长度

你需要将布尔数组中每个连续的True序列替换为该序列的长度,同时希望避免Python循环以保证效率。你的思路(正向+反向累计求和再合并)其实是完全可行的,而且是纯向量化的高效方案,只是可以优化代码结构并把逻辑解释得更清楚。下面是更简洁的实现和详细说明:

核心思路

对于每个连续True段里的任意位置:

  1. 正向累计:计算从段起始到当前位置的步数(得到类似1,2,3,...,n的序列)
  2. 反向累计:计算从当前位置到段结束的步数(得到类似n,n-1,...,1的序列)
  3. 合并两个结果:正向值 + 反向值 - 1,就能得到该段的总长度(因为当前位置在正向和反向计算中各被统计了一次)

向量化实现代码

import numpy as np

def replace_true_with_segment_length(arr, axis=2):
    # 正向累计:计算每个True位置到段开头的步数,遇到False重置
    forward = np.cumsum(arr, axis=axis)
    reset_mask = ~arr
    forward_offset = np.maximum.accumulate(forward * reset_mask, axis=axis)
    forward = forward - forward_offset

    # 反向累计:计算每个True位置到段结尾的步数,遇到False重置
    reversed_arr = np.flip(arr, axis=axis)
    backward = np.cumsum(reversed_arr, axis=axis)
    backward_reset_mask = ~reversed_arr
    backward_offset = np.maximum.accumulate(backward * backward_reset_mask, axis=axis)
    backward = backward - backward_offset
    backward = np.flip(backward, axis=axis)

    # 合并结果:True位置替换为段长度,False位置保持0
    result = forward + backward - arr.astype(int)
    return result

代码细节解释

  1. 正向累计部分
    • np.cumsum(arr, axis=axis):沿指定轴累计求和,False对应0,True对应1,连续True会生成递增序列,但遇到False后不会自动重置
    • forward_offset:通过np.maximum.accumulate记录每次遇到False时的累计值,后续的累计结果减去这个偏移量,就实现了"遇到False重置计数"的效果
  2. 反向累计部分
    • 先翻转数组,复用正向累计的逻辑,再把结果翻转回原顺序,得到每个位置到段结尾的步数
  3. 合并结果
    • 对于True位置,forward + backward -1就是段的总长度;对于False位置,forwardbackward都是0,减去0后仍为0,完全符合需求

验证示例

用你提供的一维切片测试:

# 你的测试切片
data = np.array([False]*356 + [True] + [False]*4 + [True] + [False]*186 + [True,True] + [False]*3 + [True]*8 + [False]*19 + [True,True] + [False]*6 + [True]*6 + [False]*12 + [True,True] + [False]*3 + [True] + [False]*3 + [True]*3 + [False]*2 + [True,True] + [False]*2 + [True,True] + [False]*2 + [True] + [False]*1 + [True]*6 + [False]*4 + [True]*3)

result = replace_true_with_segment_length(data.reshape(1,1,365)).flatten()
# 输出结果与你期望的完全一致
print(result)

这个方法完全基于NumPy的向量化操作,没有任何Python循环,对于(61,77,365)规模的数组,计算效率远高于循环实现,能轻松处理你的需求。

内容的提问来源于stack exchange,提问作者fandango_jones

火山引擎 最新活动