How can I efficiently differentiate tensors with PyTorch's autograd?
Hey there, I get exactly why you’re frustrated—looping through every element of a large tensor to compute gradients is painfully slow, and that RuntimeError from your Jacobian diagonal attempt is just adding to the headache. Let’s break down what’s going wrong and fix this with proper, efficient PyTorch techniques.
First, Let’s Clarify Your Goal
You want dy[i, j]/dx[i] for every sample i and output j: the per-sample entries of the Jacobian (its "diagonal" over the batch dimension), where each input x[i] maps to the outputs y[i, :]. The loop approach works but doesn't scale, and PyTorch has better tools for this.
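To pin down the target quantity, here is a slow reference version: a minimal sketch, assuming a small stand-in network (`model` is hypothetical, not the code from the question). It also checks that gradients across different samples are zero, which is what makes the "diagonal" the only interesting part.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-in for the network in the question: 1 input feature, 4 outputs.
model = nn.Sequential(nn.Linear(1, 20), nn.Tanh(), nn.Linear(20, 4))

x = torch.rand(5, 1, requires_grad=True)
y = model(x)  # shape (5, 4)

# Slow reference: one autograd call per output element.
reference = torch.zeros(y.shape)
off_diagonal_max = 0.0
for i in range(y.shape[0]):
    for j in range(y.shape[1]):
        (g,) = torch.autograd.grad(y[i, j], x, retain_graph=True)  # shape (5, 1)
        reference[i, j] = g[i, 0]  # dy[i, j] / dx[i, 0]
        # Every other row of g is the gradient w.r.t. a different sample.
        other_rows = torch.ones(y.shape[0], dtype=torch.bool)
        other_rows[i] = False
        off_diagonal_max = max(off_diagonal_max, g[other_rows].abs().max().item())

print(reference.shape)
print(off_diagonal_max)
```

The faster approaches can be compared against `reference` on a small batch before they are trusted on a large one.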
Solution 1: Use torch.autograd.functional.jacobian (Simple & Clean)
PyTorch’s built-in jacobian function can compute the entire Jacobian matrix in one go, and then you just extract the diagonal elements you need. No loops, no fuss.
Here’s how to adapt your code:
    import torch
    import torch.nn as nn

    class net_x(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(1, 20)
            self.fc2 = nn.Linear(20, 20)
            self.out = nn.Linear(20, 4)

        def forward(self, x):
            x = torch.tanh(self.fc1(x))
            x = torch.tanh(self.fc2(x))
            return self.out(x)

    nx = net_x()

    # Let's use a batch of 10 inputs (each of shape (1,))
    x = torch.rand(10, 1, requires_grad=True)

    # Compute the full Jacobian: shape (10, 4, 10, 1)
    # This captures dy[i, j]/dx[k, 0] for all i, j, k
    jacobian = torch.autograd.functional.jacobian(nx, x)

    # Keep only the entries where the output row matches the input row (dy[i, j]/dx[i, 0]).
    # torch.diagonal appends the diagonal as the LAST dimension, giving shape (4, 1, 10),
    # so move it to the front before dropping the size-1 input-feature dimension.
    diag_gradients = torch.diagonal(jacobian, dim1=0, dim2=2).permute(2, 0, 1).squeeze(-1)

    # Result shape: (10, 4); each row has the gradients of the 4 outputs of the i-th input
    print(diag_gradients.shape)
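By default, torch.autograd.functional.jacobian runs one backward pass per output element. It also accepts a `vectorize=True` flag (marked experimental in the PyTorch docs) that batches those backward passes. A minimal sketch, assuming a hypothetical stand-in `model`, checking that both modes agree:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-in network: 1 input feature, 4 outputs.
model = nn.Sequential(nn.Linear(1, 20), nn.Tanh(), nn.Linear(20, 4))
x = torch.rand(10, 1)

# Default mode: loops over output elements internally.
jac_loop = torch.autograd.functional.jacobian(model, x)
# Vectorized mode: batches the backward passes (experimental per the docs).
jac_vec = torch.autograd.functional.jacobian(model, x, vectorize=True)

print(jac_loop.shape, jac_vec.shape)
print(torch.allclose(jac_loop, jac_vec, atol=1e-6))
```

Either way the full (10, 4, 10, 1) tensor is materialized, so this only helps while the full Jacobian still fits in memory.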
Solution 2: Optimized Diagonal Gradients (Memory-Efficient for Large Tensors)
If you don't need the full Jacobian (which can get huge for big tensors), you can use torch.autograd.grad with custom grad_outputs to compute only the diagonal elements. This works because each row y[i, :] depends only on x[i] (true for this MLP, but not for layers that mix samples, such as BatchNorm in training mode), so summing an output column over the batch and differentiating gives dy[i, j]/dx[i] directly. Memory then scales with batch size times outputs instead of batch size squared times outputs.
Here's how:
    nx = net_x()
    x = torch.rand(10, 1, requires_grad=True)
    y = nx(x)  # Shape (10, 4)
    batch_size, n_out = y.shape

    # One grad_outputs tensor per output column j: ones in column j, zeros elsewhere.
    # Shape (4, 10, 4); the leading dimension indexes the vector-Jacobian products.
    grad_outputs = torch.eye(n_out, device=y.device, dtype=y.dtype)
    grad_outputs = grad_outputs.unsqueeze(1).expand(n_out, batch_size, n_out)

    # is_grads_batched=True evaluates all 4 vector-Jacobian products in one call.
    # Differentiate w.r.t. x itself: a flattened copy created after the forward
    # pass is not part of y's graph and would raise an error.
    grads = torch.autograd.grad(
        y, x, grad_outputs=grad_outputs, is_grads_batched=True
    )[0]  # Shape (4, 10, 1)

    # Reorder to (10, 4): row i holds dy[i, j]/dx[i] for the 4 outputs
    diag_grads = grads.squeeze(-1).T
    print(diag_grads)
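One way to sanity-check per-sample gradients obtained from batched grad_outputs is to compare them with the diagonal of the full Jacobian on a small batch. The sketch below assumes a hypothetical stand-in `model` whose rows are independent of each other; the names `selectors`, `per_sample`, and `expected` are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-in network: 1 input feature, 4 outputs, no cross-sample layers.
model = nn.Sequential(nn.Linear(1, 20), nn.Tanh(), nn.Linear(20, 4))

x = torch.rand(6, 1, requires_grad=True)
y = model(x)
batch, n_out = y.shape

# One selector per output column: ones in column j, zeros elsewhere. Shape (4, 6, 4).
selectors = torch.eye(n_out).unsqueeze(1).expand(n_out, batch, n_out)
(grads,) = torch.autograd.grad(y, x, grad_outputs=selectors, is_grads_batched=True)
per_sample = grads.squeeze(-1).T  # (batch, n_out)

# Reference: full Jacobian of shape (6, 4, 6, 1), then its batch diagonal.
full = torch.autograd.functional.jacobian(model, x.detach())
expected = torch.diagonal(full, dim1=0, dim2=2).permute(2, 0, 1).squeeze(-1)

print(per_sample.shape)
print(torch.allclose(per_sample, expected, atol=1e-6))
```

If the model contained a layer that mixes samples, the two results would differ, because the column-sum trick would then fold cross-sample terms into each entry.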
Solution 3: Use torch.vmap (PyTorch 2.0+, Often the Fastest Option)
If you're on PyTorch 2.0 or newer, torch.vmap (vectorized map) combined with torch.func.jacrev is the way to go. jacrev builds the Jacobian of the model for a single input, and vmap batches that computation across the inputs, so there is no Python loop and no cross-sample block of the Jacobian is ever materialized. Note that torch.autograd.functional.jacobian is not designed to compose with vmap, and that vmap maps over every argument by default, so the parameters need in_dims=None.
    import torch
    import torch.nn as nn
    from torch.func import functional_call, jacrev, vmap

    class net_x(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(1, 20)
            self.fc2 = nn.Linear(20, 20)
            self.out = nn.Linear(20, 4)

        def forward(self, x):
            x = torch.tanh(self.fc1(x))
            x = torch.tanh(self.fc2(x))
            return self.out(x)

    nx = net_x()
    params = {name: p.detach() for name, p in nx.named_parameters()}
    x = torch.rand(10, 1)  # Batch of 10 inputs

    # Run the model on a single input of shape (1,) and return its 4 outputs
    def forward_single(params, x_single):
        return functional_call(nx, params, (x_single,))

    # jacrev differentiates w.r.t. the input (argnums=1); vmap maps over the batch
    # dimension of x only and shares params across samples (in_dims=(None, 0))
    batch_jacobian = vmap(jacrev(forward_single, argnums=1), in_dims=(None, 0))(params, x)

    # batch_jacobian has shape (10, 4, 1); drop the size-1 input-feature dimension
    diag_gradients = batch_jacobian.squeeze(-1)
    print(diag_gradients.shape)  # (10, 4)
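A per-sample Jacobian built with torch.func.jacrev under vmap can be cross-checked against plain torch.autograd.grad. The sketch below assumes a hypothetical stand-in `model`; it compares the vmap result with one backward pass per output column, which is valid only because each output row depends on its own input row.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, jacrev, vmap

torch.manual_seed(0)
# Hypothetical stand-in network: 1 input feature, 4 outputs.
model = nn.Sequential(nn.Linear(1, 20), nn.Tanh(), nn.Linear(20, 4))
params = {name: p.detach() for name, p in model.named_parameters()}
x = torch.rand(8, 1)

def forward_single(params, x_single):
    # x_single has shape (1,); the result has shape (4,)
    return functional_call(model, params, (x_single,))

# Shape (8, 4, 1) before the squeeze, (8, 4) after.
per_sample = vmap(jacrev(forward_single, argnums=1), in_dims=(None, 0))(params, x).squeeze(-1)

# Reference: summing column j over the batch and differentiating gives dy[i, j]/dx[i].
x_ref = x.clone().requires_grad_(True)
y_ref = model(x_ref)
columns = [
    torch.autograd.grad(y_ref[:, j].sum(), x_ref, retain_graph=True)[0]
    for j in range(y_ref.shape[1])
]
reference = torch.cat(columns, dim=1)  # (8, 4)

print(per_sample.shape)
print(torch.allclose(per_sample, reference, atol=1e-6))
```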
Fixing Your Original Test Code
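Your earlier attempt failed because the input reshaping was off, and you didn't provide proper grad_outputs to handle non-scalar outputs. Two more details matter: a stack of grad_outputs vectors has to be passed with is_grads_batched=True, and the gradient has to be taken with respect to the tensor that actually went through the network, not a flattened copy created afterwards. Here's the corrected version using torch.autograd.grad:
    nx = net_x()
    x = torch.rand(10, requires_grad=True).reshape(10, 1)  # Correct input shape (10, 1)
    y = nx(x)  # Shape (10, 4)
    n = y.numel()  # 40 output elements

    # Identity matrix reshaped to (40, 10, 4): slice r is one-hot at flattened output position r
    grad_outputs = torch.eye(n, device=y.device, dtype=y.dtype).view(n, *y.shape)

    # One batched call; grads[r] is the gradient of output element r w.r.t. x.
    # Note that this builds the full (40, 10) gradient table.
    grads = torch.autograd.grad(
        y, x, grad_outputs=grad_outputs, is_grads_batched=True
    )[0].squeeze(-1)  # Shape (40, 10)

    # Map output indices to their corresponding input indices:
    # each input has 4 outputs, so flattened output r belongs to input r // 4
    rows = torch.arange(n, device=y.device)
    diag_grads = grads[rows, rows // y.shape[1]].view(y.shape)
    print(diag_grads)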
Why These Methods Are Way Faster
- Batch Processing: The batched grad_outputs and vmap approaches run the backward computation as a few vectorized calls instead of one Python-level autograd call per element. (torch.autograd.functional.jacobian still loops over output elements internally unless you pass vectorize=True.)
- Optimized Operations: PyTorch's backend (CPU/GPU) is optimized for batch operations, so you get much better utilization of hardware resources.
- No Redundant Work: A per-element loop either re-runs the forward pass or re-traverses the retained graph for every element, while the batch methods run one forward pass and reuse its graph.
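The points above can be checked with a rough timing sketch. It assumes a hypothetical stand-in `model` and compares a per-element loop with one backward pass per output column; the absolute numbers depend on hardware and batch size, so treat them as indicative only.

```python
import time
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-in network: 1 input feature, 4 outputs.
model = nn.Sequential(nn.Linear(1, 20), nn.Tanh(), nn.Linear(20, 4))
x = torch.rand(64, 1, requires_grad=True)

def loop_version(x):
    # One autograd call per output element: 64 * 4 backward passes.
    y = model(x)
    out = torch.zeros(y.shape)
    for i in range(y.shape[0]):
        for j in range(y.shape[1]):
            (g,) = torch.autograd.grad(y[i, j], x, retain_graph=True)
            out[i, j] = g[i, 0]
    return out

def column_version(x):
    # One autograd call per output column: 4 backward passes.
    y = model(x)
    cols = [
        torch.autograd.grad(y[:, j].sum(), x, retain_graph=True)[0]
        for j in range(y.shape[1])
    ]
    return torch.cat(cols, dim=1)

t0 = time.perf_counter()
slow = loop_version(x)
t1 = time.perf_counter()
fast = column_version(x)
t2 = time.perf_counter()
loop_seconds, column_seconds = t1 - t0, t2 - t1

print(f"loop: {loop_seconds:.4f}s, per-column: {column_seconds:.4f}s")
print(torch.allclose(slow, fast, atol=1e-6))
```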
The question originally comes from Stack Exchange; it was asked by Penguin.




