Technical question about implementing cross-entropy loss (CELoss) and the meaning of its parameters

Understanding CELoss Inputs and Implementing Forward/Backward Passes

Let's break down your questions step by step, then walk through implementing the forward and backward passes correctly.

Clarifying Inputs: What do x and x_y mean?

Looking at the code comments and standard ML conventions for cross-entropy loss:

  • x holds the model's raw logits (unnormalized scores, not probabilities) for a batch, with shape (batch_size, num_classes), where num_classes is 10 in your case.
  • y is the target class labels for each sample in the batch, shaped (batch_size,), where each value is an integer representing the correct class (e.g., 0 to 9 for 10 classes).

Your course's formula:

CELoss(x, y) = -log( exp(x_y) / Σₖ exp(xₖ) )

is exactly the single-sample cross-entropy loss computed from logits, and it does agree with Wikipedia's definition. Here's why:

  • Wikipedia's cross-entropy is H(p, q) = -Σpᵢ log(qᵢ), where p is the true distribution (a one-hot vector with 1 at the target class y, 0 elsewhere) and q is the predicted probability distribution (softmax of x).
  • For a single sample, this simplifies to -log(q_y) (since only the target class has a non-zero value in p). And q_y = exp(x_y)/Σₖexp(xₖ) (the softmax output for the target class), so substituting gives your course's formula.
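As a quick numerical sanity check of this equivalence, the sketch below computes H(p, q) with an explicit one-hot p and compares it with the course formula for a single sample. The logit values and the target class are arbitrary made-up numbers, chosen only for illustration.

```python
import numpy as np

# Arbitrary example logits for one sample (10 classes) and its target class
x = np.array([1.2, -0.3, 0.8, 2.5, 0.0, -1.1, 0.4, 1.9, -0.7, 0.6])
y = 3

# Predicted distribution q = softmax(x)
q = np.exp(x) / np.sum(np.exp(x))

# True distribution p: one-hot vector with a 1 at the target class
p = np.zeros_like(x)
p[y] = 1.0

# Wikipedia's definition: H(p, q) = -sum_i p_i * log(q_i)
h_pq = -np.sum(p * np.log(q))

# Course formula: -log(exp(x_y) / sum_k exp(x_k))
ce_course = -np.log(np.exp(x[y]) / np.sum(np.exp(x)))

print(h_pq, ce_course)  # the two numbers should agree up to rounding
```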

x_y specifically refers to the logit value in x corresponding to the target class y for a given sample. For example, if a sample's target label y is 3, x_y is the value at x[sample_index, 3].
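In NumPy, picking out x_y for every sample at once is usually done with integer-array ("fancy") indexing, which is what the forward pass below relies on. A small sketch with a hypothetical batch of 3 samples and 4 classes (kept small only so the values are easy to follow):

```python
import numpy as np

# Hypothetical batch: 3 samples, 4 classes (your case would use 10 classes)
x = np.array([[0.1, 2.0, -1.0, 0.5],
              [1.5, 0.3, 0.2, -0.4],
              [-0.2, 0.0, 0.7, 3.1]])
y = np.array([1, 0, 3])  # integer class labels, shape (batch_size,)

# Pair row i with column y[i]: selects x[0, 1], x[1, 0], x[2, 3]
x_y = x[np.arange(x.shape[0]), y]
print(x_y)  # values: 2.0, 1.5, 3.1
```

Note that y must have an integer dtype for this indexing to work; float labels (for example 3.0) raise an IndexError, so cast them with y.astype(int) first if needed.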

Implementing the Forward Pass

To avoid numerical overflow (a common issue when computing exp(x) for large logits), we use the log-sum-exp trick to stabilize calculations. Here's the completed forward method:

import numpy as np

class CELoss(object):
    @staticmethod
    def forward(x, y):
        assert len(x.shape) == 2 # x is batch of predictions (batch_size, 10)
        assert len(y.shape) == 1 # y is batch of target labels (batch_size,)
        batch_size = x.shape[0]
        
        # Extract logits corresponding to target labels for each sample
        target_logits = x[np.arange(batch_size), y]
        
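        # Compute log-sum-exp to avoid numerical overflow. max_logits has shape
        # (batch_size, 1) so that x - max_logits broadcasts row-wise; it is flattened
        # with [:, 0] before adding the row sums of shape (batch_size,), because
        # (batch_size, 1) + (batch_size,) would broadcast to (batch_size, batch_size).
        max_logits = np.max(x, axis=1, keepdims=True)
        log_sum_exp = max_logits[:, 0] + np.log(np.sum(np.exp(x - max_logits), axis=1))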
        
        # Calculate per-sample loss and average over the batch
        per_sample_loss = -target_logits + log_sum_exp
        avg_loss = np.mean(per_sample_loss)
        
        return avg_loss
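To see why subtracting the maximum matters, the sketch below compares a naive evaluation of log Σₖ exp(xₖ) with the shifted version used in forward, on a deliberately large made-up logit vector. It relies on the identity log Σₖ exp(xₖ) = m + log Σₖ exp(xₖ - m) with m = maxₖ xₖ, which keeps every exponent at or below 0.

```python
import numpy as np

# Deliberately large logits (illustrative values); exp(1000) overflows float64
x = np.array([1000.0, 999.0, 998.0])

# Naive version: exp overflows to inf, so the log is inf as well
with np.errstate(over="ignore"):
    naive = np.log(np.sum(np.exp(x)))

# Shifted version: subtract the max first, so the largest term is exp(0) = 1
m = np.max(x)
stable = m + np.log(np.sum(np.exp(x - m)))

print(naive)   # inf
print(stable)  # finite, slightly above 1000
```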

Implementing the Backward Pass

The gradient of the cross-entropy loss with respect to the logits x follows directly from the softmax derivative. For each sample:

  • The gradient for non-target classes is equal to the softmax probability of that class.
  • The gradient for the target class is (softmax probability of target class - 1).
  • We scale by dout (upstream gradient) and divide by batch size (since we averaged the loss in the forward pass).
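These rules can be derived from the per-sample loss written in log-sum-exp form, L = -x_y + log Σₖ exp(xₖ), which is the course formula with the logarithm expanded. In the sketch below, N is the batch size, sᵢⱼ is the softmax probability of class j for sample i, 𝟙[·] is the indicator function, and dout is treated as the upstream gradient ∂𝓛/∂L̄ of some downstream objective 𝓛 (this notation is chosen here for the derivation and does not appear in the original code):

```latex
\begin{aligned}
L_i &= -x_{i y_i} + \log \sum_{k} e^{x_{ik}},
\qquad
s_{ij} = \frac{e^{x_{ij}}}{\sum_{k} e^{x_{ik}}} \\
\frac{\partial L_i}{\partial x_{ij}}
  &= -\mathbb{1}[j = y_i] + \frac{e^{x_{ij}}}{\sum_{k} e^{x_{ik}}}
   = s_{ij} - \mathbb{1}[j = y_i] \\
\bar{L} &= \frac{1}{N} \sum_{i=1}^{N} L_i
\quad\Rightarrow\quad
\frac{\partial \mathcal{L}}{\partial x_{ij}}
  = \mathrm{dout} \cdot \frac{\partial \bar{L}}{\partial x_{ij}}
  = \frac{\mathrm{dout}}{N} \bigl( s_{ij} - \mathbb{1}[j = y_i] \bigr)
\end{aligned}
```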

Here's the completed backward method (it goes inside the same CELoss class as forward, hence the indentation):

    @staticmethod
    def backward(x, y, dout):
        assert len(x.shape) == 2
        assert len(y.shape) == 1
        batch_size, num_classes = x.shape
        
        # Compute stable softmax of logits
        max_logits = np.max(x, axis=1, keepdims=True)
        exp_logits = np.exp(x - max_logits)
        softmax = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)
        
        # Create one-hot encoding of target labels
        one_hot_targets = np.zeros_like(x)
        one_hot_targets[np.arange(batch_size), y] = 1
        
        # Calculate gradient with respect to logits
        dx = (softmax - one_hot_targets) * dout / batch_size
        
        dy = 0.0 # no useful gradient for y, just set it to zero
        return dx, dy
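A common way to verify a hand-written backward pass is to compare it with a central finite-difference estimate of the gradient. The sketch below is self-contained, so it repeats the forward and backward logic from above; the helper name numerical_grad, the step size h, and the random test data are illustrative assumptions, not part of the original assignment code.

```python
import numpy as np


class CELoss(object):
    @staticmethod
    def forward(x, y):
        # Mean cross-entropy over the batch, using a row-wise log-sum-exp
        batch_size = x.shape[0]
        target_logits = x[np.arange(batch_size), y]
        max_logits = np.max(x, axis=1, keepdims=True)  # shape (batch_size, 1)
        log_sum_exp = max_logits[:, 0] + np.log(np.sum(np.exp(x - max_logits), axis=1))
        return np.mean(log_sum_exp - target_logits)

    @staticmethod
    def backward(x, y, dout):
        # Gradient of dout * mean loss w.r.t. x: (softmax - one_hot) * dout / batch_size
        batch_size = x.shape[0]
        exp_logits = np.exp(x - np.max(x, axis=1, keepdims=True))
        softmax = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)
        one_hot_targets = np.zeros_like(x)
        one_hot_targets[np.arange(batch_size), y] = 1
        dx = (softmax - one_hot_targets) * dout / batch_size
        return dx, 0.0


def numerical_grad(f, x, h=1e-5):
    """Hypothetical helper: central-difference estimate of the gradient of scalar f at x."""
    x = x.astype(float)  # work on a float copy so the caller's array is untouched
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        original = x[idx]
        x[idx] = original + h
        f_plus = f(x)
        x[idx] = original - h
        f_minus = f(x)
        x[idx] = original  # restore the entry before moving on
        grad[idx] = (f_plus - f_minus) / (2.0 * h)
    return grad


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 10))     # 5 samples, 10 classes
y = rng.integers(0, 10, size=5)  # integer labels in [0, 10)
dout = 1.0

dx_analytic, _ = CELoss.backward(x, y, dout)
dx_numeric = numerical_grad(lambda z: dout * CELoss.forward(z, y), x)

# Relative error; small values (e.g. well below 1e-6) suggest the analytic gradient is right
denom = np.maximum(1e-8, np.abs(dx_analytic) + np.abs(dx_numeric))
max_rel_error = np.max(np.abs(dx_analytic - dx_numeric) / denom)
print("max relative error:", max_rel_error)
```

Because each row of softmax sums to 1 and each one-hot row also sums to 1, every row of dx should sum to zero, which is another cheap check worth keeping around.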

The question was originally posted on Stack Exchange by spadel.
