You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

Numpy是否存在或即将推出全局保持维度的设置?

如何让NumPy默认保留单例维度?

好问题!目前NumPy并没有官方支持的全局设置(比如你设想的np.setparams(keepdims=True))来让所有函数默认保留单例维度,但我们有几个实用的替代方案,能帮你避免每次手动加参数的麻烦:

方案1:封装常用函数,默认开启keepdims

最直接的办法是自己包装常用的NumPy函数,把keepdims=True设为默认参数。比如针对summean这类常用聚合函数:

import numpy as np

# 封装sum函数,默认keepdims=True
def my_sum(a, axis=None, keepdims=True, **kwargs):
    return np.sum(a, axis=axis, keepdims=keepdims, **kwargs)

# 封装mean函数
def my_mean(a, axis=None, keepdims=True, **kwargs):
    return np.mean(a, axis=axis, keepdims=keepdims, **kwargs)

# 测试效果
x = np.random.randint(10, size=(5, 10))
print(my_sum(x, axis=0).shape)  # 输出 (1, 10)

这种方式简单可控,不会影响NumPy原生函数的默认行为,适合只用到少数几个函数的场景。

方案2:自定义数组子类,强制保留维度

如果需要连切片操作(比如x[:,0])都自动保留维度,可以自定义一个继承自np.ndarray的子类,重写索引和常用方法的行为:

import numpy as np

class KeepDimArray(np.ndarray):
    def __new__(cls, input_array):
        # 将输入数组转为自定义子类实例
        return np.asarray(input_array).view(cls)
    
    def __getitem__(self, key):
        result = super().__getitem__(key)
        # 处理切片:如果原维度被索引成单值,补回维度
        if isinstance(key, tuple):
            new_shape = list(self.shape)
            for idx, k in enumerate(key):
                if isinstance(k, int):
                    new_shape[idx] = 1
            return result.reshape(new_shape).view(KeepDimArray)
        elif isinstance(key, int):
            return result.reshape((1,) + self.shape[1:]).view(KeepDimArray)
        return result.view(KeepDimArray)
    
    # 重载sum方法,默认keepdims=True
    def sum(self, axis=None, keepdims=True, **kwargs):
        return super().sum(axis=axis, keepdims=keepdims, **kwargs).view(KeepDimArray)

# 测试
x = KeepDimArray(np.random.randint(10, size=(5, 10)))
print(x.sum(axis=0).shape)  # 输出 (1, 10)
print(x[:, 0].shape)        # 输出 (5, 1)

需要注意的是,这种方法需要逐个重载你用到的NumPy方法(比如maxmin等),而且部分第三方函数可能无法完美适配自定义数组类型,适合对维度一致性要求极高的场景。

方案3:用装饰器修改原生函数(谨慎使用)

如果你想直接修改NumPy原生函数的默认参数,可以用装饰器实现,但强烈建议只在局部代码中使用,避免影响其他依赖原生行为的代码:

import numpy as np
from functools import wraps

def keepdims_default(func):
    @wraps(func)
    def wrapper(*args, keepdims=True, **kwargs):
        return func(*args, keepdims=keepdims, **kwargs)
    return wrapper

# 给sum和mean加上默认keepdims=True
np.sum = keepdims_default(np.sum)
np.mean = keepdims_default(np.mean)

# 测试
x = np.random.randint(10, size=(5, 10))
print(np.sum(x, axis=0).shape)  # 输出 (1, 10)

关于未来的全局设置

目前NumPy的官方文档和开发路线图中,还没有明确提到会推出全局keepdims设置的计划。如果这个需求对你很重要,可以关注NumPy的GitHub仓库,看看有没有相关的功能提议或正在开发的特性。

内容的提问来源于stack exchange,提问作者Jiageng

火山引擎 最新活动