PyTorch自定义DataLoader维度冗余问题:如何匹配默认Loader输出维度?
解决自定义DataLoader输出维度冗余的问题
问题根源在于你创建自定义loader时,没有显式设置batch_size=None:默认的batch_size=1会把BatchSampler返回的每个批次(4096个样本)再包装成一个大小为1的batch,最终导致输出多了一层不必要的维度。
这里有两种简便的修复方法,都不需要子类化DataLoader:
方法一:设置batch_size=None(推荐)
这是PyTorch官方推荐的用法——当使用BatchSampler作为sampler参数时,必须禁用DataLoader默认的batch打包逻辑,避免双重打包。修改后的tensor_loader代码如下:
def tensor_loader(dataset: TensorDataset, batch_size: int): return DataLoader( dataset=dataset, sampler=BatchSampler( sampler=RandomSampler(dataset), # 等价于shuffle=True batch_size=batch_size, drop_last=True ), batch_size=None # 关键:取消默认的batch打包 )
修改后,自定义loader的输出维度会和默认DataLoader完全一致:
assert next(iter(tensor_loader(dataset, 4096)))[0].shape == torch.Size([4096, 10])
方法二:自定义collate_fn挤压冗余维度
如果不想改动batch_size参数,可以通过自定义collate_fn来去除多余的外层维度。这个函数会在DataLoader打包每个batch时自动调用:
def tensor_loader(dataset: TensorDataset, batch_size: int): def collate_fn(batch): # 处理双重打包导致的冗余维度 if isinstance(batch[0], tuple): # 适配TensorDataset的tuple格式 return tuple(torch.cat(sub_batch, dim=0) for sub_batch in zip(*batch)) return torch.cat(batch, dim=0) return DataLoader( dataset=dataset, sampler=BatchSampler( sampler=RandomSampler(dataset), batch_size=batch_size, drop_last=True ), collate_fn=collate_fn )
同样可以实现和默认loader一致的输出维度,不过方法一更简洁且符合官方规范,优先推荐。
内容的提问来源于stack exchange,提问作者philosofool




