如何运用Chain Rule(链式法则)推导梯度∂/∂W ||y - Wz||²?
我最近在研究这个梯度的计算:
$$\frac{\partial}{\partial W} || y- W z||^2$$
其中变量的维度是:$W \in \mathbb{R}^{m \times n}, z \in \mathbb{R}^n, y \in \mathbb{R}^m$。
其实我知道可以直接展开计算得到结果,过程是这样的:
$$
\begin{align*}
\frac{\partial}{\partial W} || y- W z||^2 &= \frac{\partial}{\partial W} ( y- W z)^T ( y- W z) \
&= \frac{\partial}{\partial W} \left[y^Ty -2 y^T W z + (W z)^T (W z) \right] \
&= -2 y z^T + 2 W z z^T \
&= 2(W z - y) z^T
\end{align*}
$$
不过我特意想用链式法则来推导,这样能更清晰地理解复合函数求导的逻辑,而不是直接展开硬算。我们先定义中间变量$h := Wz$,接下来分三步完成推导:
第一步:计算外层函数对h的梯度
首先看外层函数$f(h) = ||y - h||^2$,我们先求它对h的梯度:
$$
\begin{align*}
\nabla_h f(h) &= \frac{\partial}{\partial h} ||y - h||^2 \
&= \frac{\partial}{\partial h} (y - h)^T(y - h) \
&= \frac{\partial}{\partial h} \left(y^T y - y^T h - h^T y + h^T h\right) \
&= 0 - y - y + 2h \
&= 2(h - y)
\end{align*}
$$
这里要注意,$y^T h$和$h^T y$是同一个标量,对h求导的结果都是y,两者相加就是2y,前面带负号所以得到-2y;再加上$h^T h$对h求导的结果2h,最终得到这个m维的列向量。
第二步:计算中间变量h对W的导数
接下来求h对W的导数:h是$Wz$,对于W中任意一个元素$W_{ij}$(第i行第j列),h的第i个元素$h_i = \sum_{k=1}^n W_{ik} z_k$,所以$\frac{\partial h_i}{\partial W_{ij}} = z_j$;而h的其他元素$h_k (k \neq i)$对$W_{ij}$的导数都是0。
用矩阵导数的实用形式来看,当我们把一个m维列向量(比如刚才得到的$\nabla_h f(h)$)和这个导数结合时,等价于将该列向量右乘$z^T$(n维行向量),这样就能得到一个和W维度一致的m×n矩阵。
第三步:链式法则组合结果
根据链式法则,原梯度等于外层函数对h的梯度与h对W的导数的乘积,代入计算:
$$
\begin{align*}
\frac{\partial}{\partial W} ||y - Wz||^2 &= \nabla_h f(h) \cdot \frac{\partial h}{\partial W} \
&= 2(h - y) z^T
\end{align*}
$$
最后把$h = Wz$代回去,就得到和直接展开完全一致的结果:
$$
\frac{\partial}{\partial W} ||y - Wz||^2 = 2(Wz - y) z^T
$$
备注:内容来源于stack exchange,提问作者Luca9984




