You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

PyTorch中DCGAN在Google Colab GPU训练迭代耗时过长求助

排查DCGAN训练迭代耗时过长的关键问题

我帮你梳理了代码里几个严重影响训练速度的核心问题,按优先级逐一说明:

1. 完全手动的低效数据加载流程

这是拖慢训练的头号元凶!你现在用纯Python循环逐个读取图片、处理、拼接,完全没利用PyTorch的并行加载能力:

  • 每次迭代都在CPU上串行处理64张图片,cv2.imread+转置+reshape+列表append的过程非常耗时,没有多线程加速
  • 手动用np.concatenate拼接张量再转PyTorch张量,中间的CPU-GPU数据传输也没有优化

修复方案
用PyTorch的DatasetDataLoader重构数据加载逻辑,开启多线程并行加载:

from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

class CustomDataset(Dataset):
    def __init__(self, img_dir, crop_dir):
        self.img_dir = img_dir
        self.crop_dir = crop_dir
        self.img_list = os.listdir(img_dir)

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        img_name = self.img_list[idx]
        # 处理训练图:转RGB+归一化+调整为PyTorch默认(C,H,W)格式
        train_img = cv2.imread(os.path.join(self.img_dir, img_name))
        train_img = cv2.cvtColor(train_img, cv2.COLOR_BGR2RGB) / 255.0
        train_img = torch.tensor(train_img).permute(2, 0, 1)

        # 处理采样图:增加通道维度
        sample_img = cv2.imread(os.path.join(self.crop_dir, img_name), 0) / 255.0
        sample_img = torch.tensor(sample_img).unsqueeze(0)
        
        return train_img, sample_img

# 初始化数据集和DataLoader,num_workers根据Colab资源调整
dataset = CustomDataset(path, path2)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)

之后训练循环直接遍历dataloader即可,它会自动完成批量加载、并行处理,还能加速CPU-GPU数据传输。

2. 生成器训练逻辑错误(既影响效率也影响效果)

你在训练生成器时写了:

D_G_z = dis(G_z.detach()).view(-1)

这里的G_z.detach()会切断生成器的梯度传播链,导致生成器完全无法得到更新!不仅训练无效,还会浪费计算资源做无用功。正确写法应该是:

# 训练生成器时重新生成G_z(或保留之前的G_z但不要detach)
G_z = Gen(sample_image)
D_G_z = dis(G_z).view(-1)
label.fill_(real_label)
error_gen = GAN_loss(D_G_z, label)
error_gen.backward()
G_optimizer.step()

另外,训练判别器时不需要同时调用Gen.zero_grad(),只需要清零判别器的梯度即可,生成器的梯度清零放在生成器训练步骤前就行,减少冗余操作。

3. 冗余操作与过时API

  • Variable在PyTorch 0.4.0之后已经废弃,直接用torch.tensor即可,不用再包裹Variable
  • 图片处理的.T转置换成PyTorch的permute,更高效且符合框架操作习惯
  • assert语句移到数据集的__getitem__里,只在初始化时检查一次,不用每次迭代都执行

4. 其他小优化点

  • 提前定义real_labelfake_label为张量,比如real_label = torch.tensor(1.0, device=device),不用每次循环调用torch.full
  • 减少训练时的打印频率,比如每10次迭代打印一次,降低IO开销
  • !nvidia-smi确认Colab确实在使用GPU,避免不小心切换到CPU运行

按照这些方案修改后,单次迭代耗时应该能降到几秒以内,训练效率会有质的提升。

内容的提问来源于stack exchange,提问作者TechVision

火山引擎 最新活动