You need to enable JavaScript to run this app.
优惠活动
大模型
产品
解决方案
定价
更多
文档控制台
免费开始使用

PyTorch自定义DataLoader维度冗余问题:如何匹配默认Loader输出维度?

解决自定义DataLoader输出维度冗余的问题

问题根源在于你创建自定义loader时,没有显式设置batch_size=None:默认的batch_size=1会把BatchSampler返回的每个批次(4096个样本)再包装成一个大小为1的batch,最终导致输出多了一层不必要的维度。

这里有两种简便的修复方法,都不需要子类化DataLoader

方法一:设置batch_size=None(推荐)

这是PyTorch官方推荐的用法——当使用BatchSampler作为sampler参数时,必须禁用DataLoader默认的batch打包逻辑,避免双重打包。修改后的tensor_loader代码如下:

def tensor_loader(dataset: TensorDataset, batch_size: int):
    return DataLoader(
        dataset=dataset,
        sampler=BatchSampler(
            sampler=RandomSampler(dataset),  # 等价于shuffle=True
            batch_size=batch_size,
            drop_last=True
        ),
        batch_size=None  # 关键:取消默认的batch打包
    )

修改后,自定义loader的输出维度会和默认DataLoader完全一致:

assert next(iter(tensor_loader(dataset, 4096)))[0].shape == torch.Size([4096, 10])

方法二:自定义collate_fn挤压冗余维度

如果不想改动batch_size参数,可以通过自定义collate_fn来去除多余的外层维度。这个函数会在DataLoader打包每个batch时自动调用:

def tensor_loader(dataset: TensorDataset, batch_size: int):
    def collate_fn(batch):
        # 处理双重打包导致的冗余维度
        if isinstance(batch[0], tuple):
            # 适配TensorDataset的tuple格式
            return tuple(torch.cat(sub_batch, dim=0) for sub_batch in zip(*batch))
        return torch.cat(batch, dim=0)
    
    return DataLoader(
        dataset=dataset,
        sampler=BatchSampler(
            sampler=RandomSampler(dataset),
            batch_size=batch_size,
            drop_last=True
        ),
        collate_fn=collate_fn
    )

同样可以实现和默认loader一致的输出维度,不过方法一更简洁且符合官方规范,优先推荐。

内容的提问来源于stack exchange,提问作者philosofool

火山引擎 最新活动