
How to efficiently compute tensor gradients with PyTorch’s autograd?

Efficient Tensor Gradient Calculation in PyTorch (Ditching Slow Loops)

Hey there, I get exactly why you’re frustrated—looping through every element of a large tensor to compute gradients is painfully slow, and that RuntimeError from your Jacobian diagonal attempt is just adding to the headache. Let’s break down what’s going wrong and fix this with proper, efficient PyTorch techniques.

First, Let’s Clarify Your Goal

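You’re trying to compute the gradient of each output element with respect to its own input element, i.e. dy[i, j]/dx[i, 0], where input row x[i] produces output row y[i, :]. So you want the (block) diagonal of the Jacobian, right? Because net_x processes every row of the batch independently, all the other entries dy[i, j]/dx[k, 0] with i ≠ k are zero, and that is what makes loop-free shortcuts possible. The loop approach works but doesn’t scale, and PyTorch has better tools for this.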
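A quick way to convince yourself of that block-diagonal structure is to build the full Jacobian for a tiny batch and check that every cross-sample entry vanishes. This sketch assumes an nn.Sequential stand-in with net_x’s layer sizes (1 → 20 → 20 → 4 with tanh); the batch size of 5 is arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Stand-in for net_x: same layer sizes and activations
model = nn.Sequential(
    nn.Linear(1, 20), nn.Tanh(),
    nn.Linear(20, 20), nn.Tanh(),
    nn.Linear(20, 4),
)
x = torch.rand(5, 1)

# Full Jacobian of the (5, 4) output w.r.t. the (5, 1) input: shape (5, 4, 5, 1)
jac = torch.autograd.functional.jacobian(model, x)

# Mask out the entries where output row i equals input row k;
# whatever remains are the cross-sample derivatives dy[i, j]/dx[k, 0], i != k
mask = torch.eye(5, dtype=torch.bool)[:, None, :, None]  # (5, 1, 5, 1), broadcasts over jac
off_diagonal = jac.masked_fill(mask, 0.0)
print(off_diagonal.abs().max())
```

If the model mixed samples (for example BatchNorm in training mode), this check would fail and the shortcuts that rely on row independence would no longer apply.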

Solution 1: Use torch.autograd.functional.jacobian (Simple & Clean)

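PyTorch’s built-in jacobian function computes the entire Jacobian in one call, and then you just extract the diagonal elements you need. No manual loops, no fuss. Two caveats: by default it still runs one backward pass per output element internally, and for a batch of N the full Jacobian has N × 4 × N × 1 entries, so its memory grows quadratically with batch size even though the off-diagonal blocks are all zero. It’s a good fit for small batches.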

Here’s how to adapt your code:

import torch
import torch.nn as nn

class net_x(nn.Module):
    def __init__(self):
        super(net_x, self).__init__()
        self.fc1 = nn.Linear(1, 20)
        self.fc2 = nn.Linear(20, 20)
        self.out = nn.Linear(20, 4)
    def forward(self, x):
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        return self.out(x)

nx = net_x()
# Let's use a batch of 10 inputs (each shape (1,))
x = torch.rand(10, 1, requires_grad=True)

# Compute the full Jacobian: shape (10, 4, 10, 1)
# This captures dy[i,j]/dx[k,0] for all i,j,k
jacobian = torch.autograd.functional.jacobian(nx, x)

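# Extract the entries where output row matches input row (dy[i,j]/dx[i,0]):
# squeeze the trailing input-feature dim -> (10, 4, 10), take the diagonal over the
# two batch dims (0 and 2), which torch.diagonal appends last -> (4, 10), then transpose
diag_gradients = torch.diagonal(jacobian.squeeze(-1), dim1=0, dim2=2).T
# Result shape: (10, 4); row i holds the gradients of the 4 outputs of the i-th input
print(diag_gradients.shape)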
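torch.autograd.functional.jacobian also takes vectorize=True, which in turn allows strategy="forward-mode". Forward mode costs one Jacobian-vector product per input element (10 here) rather than one backward pass per output element (40 here), so it tends to suit cases like this one where outputs outnumber inputs. A minimal sketch, again assuming an nn.Sequential stand-in for net_x; whether it is actually faster on your hardware is worth measuring.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(  # stand-in for net_x
    nn.Linear(1, 20), nn.Tanh(),
    nn.Linear(20, 20), nn.Tanh(),
    nn.Linear(20, 4),
)
x = torch.rand(10, 1)

# Default: reverse mode, looping over the 40 output elements internally
jac_rev = torch.autograd.functional.jacobian(model, x)
# Forward mode requires vectorize=True; one batched JVP per input element (10 here)
jac_fwd = torch.autograd.functional.jacobian(model, x, vectorize=True, strategy="forward-mode")

# Same diagonal extraction as before: (10, 4, 10) -> diagonal (4, 10) -> transpose (10, 4)
diag_fwd = torch.diagonal(jac_fwd.squeeze(-1), dim1=0, dim2=2).T
print(diag_fwd.shape)
```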

Solution 2: Optimized Diagonal Gradients (Memory-Efficient for Large Tensors)

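If you don’t need the full Jacobian (which gets huge for big batches), you can use torch.autograd.grad with custom grad_outputs to compute only the diagonal blocks. Since the off-diagonal blocks are zero, a cotangent that is one-hot along the output dimension and all ones along the batch dimension returns dy[i, j]/dx[i, 0] for every i at once. Passing is_grads_batched=True lets you stack the four cotangents (one per output feature) into a single call, so the result is N × 4 values instead of the N × 4 × N × 1 full Jacobian. This relies on the rows being independent, which holds for net_x but not for models that mix samples, such as BatchNorm in training mode.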

Here’s how:

nx = net_x()
x = torch.rand(10, 1, requires_grad=True)
y = nx(x)  # Shape (10, 4)

# One cotangent per output column j: grad_outputs[j] has y's shape (10, 4),
# with ones in column j and zeros elsewhere -> overall shape (4, 10, 4)
grad_outputs = torch.eye(y.shape[1], device=y.device).unsqueeze(1).repeat(1, y.shape[0], 1)

# is_grads_batched=True treats dim 0 of grad_outputs as a batch of cotangents and
# runs all 4 vector-Jacobian products in one vectorized call.
# Differentiate w.r.t. x itself: a flattened copy of x is not part of the graph.
grads = torch.autograd.grad(y, x, grad_outputs=grad_outputs, is_grads_batched=True)[0]  # (4, 10, 1)

# grads[j, i, 0] = sum_k dy[k, j]/dx[i, 0] = dy[i, j]/dx[i, 0], because rows don't interact
diag_grads = grads.squeeze(-1).T  # Shape (10, 4), matching y
print(diag_grads)
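Before relying on a vectorized trick, it is worth checking it against the slow per-element loop on a small batch. This sketch assumes an nn.Sequential stand-in for net_x, and the double loop is an assumed reconstruction of the loop-based baseline (one backward pass per output element), not your exact original code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(  # stand-in for net_x
    nn.Linear(1, 20), nn.Tanh(),
    nn.Linear(20, 20), nn.Tanh(),
    nn.Linear(20, 4),
)
x = torch.rand(6, 1, requires_grad=True)
y = model(x)  # shape (6, 4)
n, m = y.shape

# Slow baseline: one backward pass per output element, keeping only dy[i, j]/dx[i, 0]
loop_grads = torch.zeros(n, m)
for i in range(n):
    for j in range(m):
        g, = torch.autograd.grad(y[i, j], x, retain_graph=True)
        loop_grads[i, j] = g[i, 0]

# Vectorized: one batched backward call with a one-hot cotangent per output column
eye_cols = torch.eye(m).unsqueeze(1).repeat(1, n, 1)  # (m, n, m); eye_cols[j, :, j] == 1
g, = torch.autograd.grad(y, x, grad_outputs=eye_cols, is_grads_batched=True)  # (m, n, 1)
vec_grads = g.squeeze(-1).T  # (n, m)

print(torch.allclose(loop_grads, vec_grads, atol=1e-6))
```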

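Solution 3: Use torch.func.vmap with jacrev (PyTorch 2.0+)

If you’re on PyTorch 2.0 or newer, torch.func lets you write the Jacobian for a single sample and vmap (vectorized map) it across the batch. The cross-sample blocks are never built, and the per-sample work is vectorized instead of looped in Python. One catch: torch.autograd.functional.jacobian calls torch.autograd.grad internally and is not supported inside torch.func.vmap, so use the composable torch.func.jacrev instead.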

import torch
import torch.nn as nn
from torch.func import vmap, functional_call

class net_x(nn.Module):
    def __init__(self):
        super(net_x, self).__init__()
        self.fc1 = nn.Linear(1, 20)
        self.fc2 = nn.Linear(20, 20)
        self.out = nn.Linear(20, 4)
    def forward(self, x):
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        return self.out(x)

nx = net_x()
# Detached copies of the weights, as is usual with torch.func transforms
params = {k: v.detach() for k, v in nx.named_parameters()}
x = torch.rand(10, 1)  # Batch of 10 inputs

# Forward pass for a single sample: x_single has shape (1,), output has shape (4,)
def single_forward(params, x_single):
    return functional_call(nx, params, (x_single,))

# jacrev composes with vmap. argnums=1 differentiates w.r.t. the input only;
# in_dims=(None, 0) shares params across samples and maps over the batch dim of x.
per_sample_jac = vmap(torch.func.jacrev(single_forward, argnums=1), in_dims=(None, 0))(params, x)  # (10, 4, 1)

# Squeeze the size-1 input dim: shape (10, 4), row i holds dy[i, j]/dx[i, 0]
diag_gradients = per_sample_jac.squeeze(-1)
print(diag_gradients.shape)
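Because each sample here has a single input feature, forward-mode AD can produce every dy[i, j]/dx[i, 0] from one Jacobian-vector product: push a tangent of ones through the network, and since rows never interact, row i of the output tangent is exactly the derivative of y[i, :] w.r.t. x[i, 0]. A sketch using torch.func.jvp (PyTorch 2.0+) with an nn.Sequential stand-in for net_x; it relies on the same row-independence assumption and does not carry over directly to inputs with more than one feature.

```python
import torch
import torch.nn as nn
from torch.func import jvp

torch.manual_seed(0)
model = nn.Sequential(  # stand-in for net_x
    nn.Linear(1, 20), nn.Tanh(),
    nn.Linear(20, 20), nn.Tanh(),
    nn.Linear(20, 4),
)
x = torch.rand(10, 1)

# One forward pass carrying a tangent of ones:
# tangent_out[i, j] = sum_k dy[i, j]/dx[k, 0], which reduces to dy[i, j]/dx[i, 0]
# because the batch rows are processed independently
y, diag_grads = jvp(model, (x,), (torch.ones_like(x),))
print(diag_grads.shape)  # torch.Size([10, 4])
```

This needs a single forward pass with no backward pass at all, which is why it can be attractive when the per-sample input is one-dimensional.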

Fixing Your Original Test Code

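Your earlier attempt failed because the input reshaping was off, and you didn’t provide grad_outputs, which torch.autograd.grad requires when the output is not a scalar. Here’s a corrected version using plain torch.autograd.grad, with one backward pass per output feature instead of one per element: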

nx = net_x()
x = torch.rand(10, 1, requires_grad=True)  # Leaf tensor with the (10, 1) shape nn.Linear(1, 20) expects
y = nx(x)  # Shape (10,4)

# y[:, j] is non-scalar, so autograd.grad needs grad_outputs. A vector of ones sums
# column j over the batch; since rows don't interact, the gradient w.r.t. x is
# dy[i, j]/dx[i, 0] for every i. That's 4 backward passes instead of 40.
# Differentiate w.r.t. x itself: a flattened copy of x is not part of the graph.
diag_grads = torch.stack([
    torch.autograd.grad(y[:, j], x, grad_outputs=torch.ones_like(y[:, j]), retain_graph=True)[0].squeeze(-1)
    for j in range(y.shape[1])
], dim=1)  # Shape (10, 4)

print(diag_grads)

Why These Methods Are Way Faster

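  • Fewer Passes: A per-element loop runs one backward pass (often plus a forward pass) per output element, N × 4 in total. The grad_outputs tricks need one backward pass per output feature (4 here) or a single batched call, and vmap vectorizes the per-sample work instead of looping in Python. Plain jacobian() is the exception: by default it still loops over output elements internally, so keep it for small batches or pass vectorize=True.
  • Optimized Operations: PyTorch’s backend (CPU/GPU) is optimized for batched tensor operations, so one large operation uses the hardware far better than many tiny ones.
  • No Redundant Work: Loops re-run the forward pass or re-traverse the same graph for each element, while the batched methods build the graph once and reuse it.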
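To see the difference on your own machine, you can time a per-element loop against the per-column grad_outputs approach. This is a rough benchmark sketch under stated assumptions: an nn.Sequential stand-in for net_x, an arbitrary batch of 200, a loop that does one forward pass per sample and one backward pass per output (an assumed version of the slow approach), and time.perf_counter as a coarse CPU timer. No particular timings are implied; only the agreement of the two results is checked.

```python
import time
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(  # stand-in for net_x
    nn.Linear(1, 20), nn.Tanh(),
    nn.Linear(20, 20), nn.Tanh(),
    nn.Linear(20, 4),
)
x = torch.rand(200, 1)

def loop_version(model, x):
    # N forward passes and N * m backward passes
    rows = []
    for i in range(x.shape[0]):
        xi = x[i:i + 1].detach().requires_grad_(True)
        yi = model(xi)  # shape (1, m)
        rows.append(torch.stack([
            torch.autograd.grad(yi[0, j], xi, retain_graph=True)[0][0, 0]
            for j in range(yi.shape[1])
        ]))
    return torch.stack(rows)  # shape (N, m)

def vectorized_version(model, x):
    # One forward pass and m backward passes (one per output column)
    x = x.detach().requires_grad_(True)
    y = model(x)
    return torch.stack([
        torch.autograd.grad(y[:, j].sum(), x, retain_graph=True)[0][:, 0]
        for j in range(y.shape[1])
    ], dim=1)  # shape (N, m)

results = {}
for fn in (loop_version, vectorized_version):
    start = time.perf_counter()
    results[fn.__name__] = fn(model, x)
    print(f"{fn.__name__}: {time.perf_counter() - start:.4f} s")

print(torch.allclose(results["loop_version"], results["vectorized_version"], atol=1e-6))
```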

The question in this post was sourced from Stack Exchange; it was asked by Penguin.
