You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何在GPU上执行PyWavelets计算以加速分类器训练?

如何在GPU上执行PyWavelets计算以加速分类器训练?

嘿,我太懂你现在的糟心情况了——把数据在GPU和CPU之间来回倒腾,再加上PyWavelets在CPU上啃计算,CPU肯定堵得死死的,训练速度能快才怪!问题核心就出在你当前的代码逻辑里:每次都把GPU张量拉到CPU转成numpy数组,算完小波变换再塞回GPU,这中间的数据传输开销CPU计算负载就是拖慢训练的元凶。

PyWavelets本身是不支持GPU张量操作的,所以咱们得换个思路:用PyTorch原生的GPU兼容操作来实现你要的Haar小波变换,让所有计算都在GPU上完成,彻底跟CPU说拜拜。

刚好你用的是Haar小波,它的变换逻辑特别简单,咱们直接用卷积来模拟就行,完全适配GPU加速。下面是修改后的代码,我给你掰碎了讲:

import torch
import torch.nn as nn
import torch.nn.functional as F

class WaveletLayer(nn.Module):
    def __init__(self, in_channels):
        super(WaveletLayer, self).__init__()
        # 定义Haar小波的四个卷积核,对应LL、LH、HL、HH四个分量
        haar_ll = torch.tensor([[1.0, 1.0], [1.0, 1.0]]) / 2.0  # 低通分量
        haar_lh = torch.tensor([[1.0, 1.0], [-1.0, -1.0]]) / 2.0  # 水平高通
        haar_hl = torch.tensor([[1.0, -1.0], [1.0, -1.0]]) / 2.0  # 垂直高通
        haar_hh = torch.tensor([[1.0, -1.0], [-1.0, 1.0]]) / 2.0  # 对角高通
        
        # 把核拼接成适合分组卷积的形状,每个输入通道对应一组核
        kernels = torch.stack([haar_ll, haar_lh, haar_hl, haar_hh], dim=0)
        kernels = kernels.repeat(in_channels, 1, 1, 1)
        
        # 注册为固定参数(不需要学习),会自动随模型迁移到GPU
        self.register_buffer('kernels', kernels)
        self.in_channels = in_channels

    def forward(self, x):
        # 用分组卷积实现单通道独立小波变换,全程GPU计算
        output = F.conv2d(
            x, 
            self.kernels, 
            groups=self.in_channels,  # 每个输入通道单独处理,和原逻辑一致
            stride=2,  # 对应小波变换的下采样,输出尺寸为原尺寸的1/2
            padding=0
        )
        # 输出形状和原代码拼接后的结果完全一致,直接传给ResNet块即可
        return output

为啥这个版本能解决问题?

  1. 彻底告别CPU-GPU数据传输:所有操作都是PyTorch原生的GPU张量运算,没有cpu().numpy()torch.from_numpy()这类拖后腿的步骤。
  2. 分组卷积适配原逻辑:分组卷积让每个输入通道单独做小波变换,和你原来循环处理每个通道的逻辑完全对齐,但效率是GPU级别的并行计算。
  3. 输出兼容原有流程:最终输出的张量通道数是4*in_channels,和你原代码拼接LL/LH/HL/HH后的通道数完全一致,后面的ResNet块不需要做任何修改。

怎么用?

初始化的时候传入输入通道数就行,比如你的输入是3通道RGB图,就这么写:

wavelet_layer = WaveletLayer(in_channels=3).to('cuda')

之后的训练流程和原来一模一样,但你会发现CPU使用率立刻降下来,训练速度会有明显提升——终于能把GPU的算力用满了!

额外小提示

如果以后你想用Haar之外的小波(比如db系列),可以找找基于PyTorch的GPU加速小波库,或者自己推导对应小波的卷积核。不过就你当前的需求来说,这个Haar卷积实现完全够用,而且是最轻量化的GPU加速方案。

备注:内容来源于stack exchange,提问作者Aryan Raj

火山引擎 最新活动