PyTorch中DCGAN在Google Colab GPU训练迭代耗时过长求助
排查DCGAN训练迭代耗时过长的关键问题
我帮你梳理了代码里几个严重影响训练速度的核心问题,按优先级逐一说明:
1. 完全手动的低效数据加载流程
这是拖慢训练的头号元凶!你现在用纯Python循环逐个读取图片、处理、拼接,完全没利用PyTorch的并行加载能力:
- 每次迭代都在CPU上串行处理64张图片,
cv2.imread+转置+reshape+列表append的过程非常耗时,没有多线程加速 - 手动用
np.concatenate拼接张量再转PyTorch张量,中间的CPU-GPU数据传输也没有优化
修复方案:
用PyTorch的Dataset和DataLoader重构数据加载逻辑,开启多线程并行加载:
from torch.utils.data import Dataset, DataLoader import torchvision.transforms as transforms class CustomDataset(Dataset): def __init__(self, img_dir, crop_dir): self.img_dir = img_dir self.crop_dir = crop_dir self.img_list = os.listdir(img_dir) def __len__(self): return len(self.img_list) def __getitem__(self, idx): img_name = self.img_list[idx] # 处理训练图:转RGB+归一化+调整为PyTorch默认(C,H,W)格式 train_img = cv2.imread(os.path.join(self.img_dir, img_name)) train_img = cv2.cvtColor(train_img, cv2.COLOR_BGR2RGB) / 255.0 train_img = torch.tensor(train_img).permute(2, 0, 1) # 处理采样图:增加通道维度 sample_img = cv2.imread(os.path.join(self.crop_dir, img_name), 0) / 255.0 sample_img = torch.tensor(sample_img).unsqueeze(0) return train_img, sample_img # 初始化数据集和DataLoader,num_workers根据Colab资源调整 dataset = CustomDataset(path, path2) dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
之后训练循环直接遍历dataloader即可,它会自动完成批量加载、并行处理,还能加速CPU-GPU数据传输。
2. 生成器训练逻辑错误(既影响效率也影响效果)
你在训练生成器时写了:
D_G_z = dis(G_z.detach()).view(-1)
这里的G_z.detach()会切断生成器的梯度传播链,导致生成器完全无法得到更新!不仅训练无效,还会浪费计算资源做无用功。正确写法应该是:
# 训练生成器时重新生成G_z(或保留之前的G_z但不要detach) G_z = Gen(sample_image) D_G_z = dis(G_z).view(-1) label.fill_(real_label) error_gen = GAN_loss(D_G_z, label) error_gen.backward() G_optimizer.step()
另外,训练判别器时不需要同时调用Gen.zero_grad(),只需要清零判别器的梯度即可,生成器的梯度清零放在生成器训练步骤前就行,减少冗余操作。
3. 冗余操作与过时API
Variable在PyTorch 0.4.0之后已经废弃,直接用torch.tensor即可,不用再包裹Variable- 图片处理的
.T转置换成PyTorch的permute,更高效且符合框架操作习惯 - 把
assert语句移到数据集的__getitem__里,只在初始化时检查一次,不用每次迭代都执行
4. 其他小优化点
- 提前定义
real_label和fake_label为张量,比如real_label = torch.tensor(1.0, device=device),不用每次循环调用torch.full - 减少训练时的打印频率,比如每10次迭代打印一次,降低IO开销
- 用
!nvidia-smi确认Colab确实在使用GPU,避免不小心切换到CPU运行
按照这些方案修改后,单次迭代耗时应该能降到几秒以内,训练效率会有质的提升。
内容的提问来源于stack exchange,提问作者TechVision




