Deriving the Partial Derivatives in Gradient Descent via the Computation Graph

The Logistic Regression Gradient Descent lesson of the Neural Networks and Deep Learning course uses logistic regression as an example to introduce computing the partial derivatives in gradient descent via a computation graph, but it does not give the detailed derivation.

The model in the example is:

(1)   \begin{equation*} z = w^Tx + b \end{equation*}

The prediction is:

(2)   \begin{equation*} \hat y = a = \sigma(z) \end{equation*}

where \sigma(z) is the Sigmoid function:

(3)   \begin{equation*} \sigma(z) = \frac{1}{1 + e^{-z}}  \end{equation*}

The loss function is:

(4)   \begin{equation*} L(a, y) = -(y\log(a) + (1 - y)\log(1 - a)) \end{equation*}

Assuming there are only two features x_{1} and x_{2}, we have:

(5)   \begin{equation*} w^T = \begin{bmatrix} w_{1} & w_{2} \end{bmatrix} \end{equation*}
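Under this two-feature setup, the forward pass can be sketched in Python (the weights, inputs, and label below are hypothetical values chosen purely for illustration):

```python
import math

def sigmoid(z):
    # Sigmoid function, equation (3)
    return 1.0 / (1.0 + math.exp(-z))

def loss(a, y):
    # Loss for a single example, equation (4)
    return -(y * math.log(a) + (1 - y) * math.log(1 - a))

# Hypothetical parameters, inputs, and label
w1, w2, b = 0.5, -0.3, 0.1
x1, x2, y = 1.0, 2.0, 1

z = w1 * x1 + w2 * x2 + b   # z = w^T x + b, equation (1)
a = sigmoid(z)              # prediction, equation (2)
L = loss(a, y)
```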

The computation graph is shown in Figure 1:

Figure 1

The backward pass computes each partial derivative as follows:

  • First, compute \frac{\partial L}{\partial a}:

(6)   \begin{equation*} \frac{\partial L}{\partial a} = - \frac{y}{a} + \frac{1 - y}{1 - a} \end{equation*}

  • Then, the chain rule gives \frac{\partial L}{\partial z}:

(7)   \begin{equation*} \frac{\partial L}{\partial z} = \frac{\partial L}{\partial a} \cdot \frac{da}{dz} \end{equation*}

where a = \sigma(z) is the Sigmoid function, so:

(8)   \begin{equation*} \frac{d\sigma(z)}{dz} = \sigma(z)(1 - \sigma(z)) \end{equation*}
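For completeness (this step is not spelled out in the original), equation (8) follows directly from the definition in (3):

```latex
\begin{aligned}
\frac{d\sigma(z)}{dz}
  &= \frac{d}{dz}\,(1 + e^{-z})^{-1} \\
  &= \frac{e^{-z}}{(1 + e^{-z})^{2}} \\
  &= \frac{1}{1 + e^{-z}} \cdot \left(1 - \frac{1}{1 + e^{-z}}\right) \\
  &= \sigma(z)(1 - \sigma(z))
\end{aligned}
```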

Substituting equations (6) and (8) into equation (7) gives:

(9)   \begin{equation*} \begin{aligned} \frac{\partial L}{\partial z} &= \left(- \frac{y}{a} + \frac{1 - y}{1 - a}\right) \cdot a(1 - a) \\ &= -y(1 - a) + a(1 - y) \\ &= a - y \end{aligned} \end{equation*}
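The result in equation (9) can be checked numerically with a central finite difference; this sketch (the values of z and y are hypothetical) compares the analytic gradient a - y against the numerical approximation:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def loss_of_z(z, y):
    # L viewed as a function of z, composing equations (2) and (4)
    a = sigmoid(z)
    return -(y * math.log(a) + (1 - y) * math.log(1 - a))

z, y = 0.7, 1   # hypothetical values
eps = 1e-6

# Central finite difference approximation of dL/dz
numeric = (loss_of_z(z + eps, y) - loss_of_z(z - eps, y)) / (2 * eps)
analytic = sigmoid(z) - y   # a - y, equation (9)
```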

  • Finally, compute \frac{\partial L}{\partial w_{1}}, \frac{\partial L}{\partial w_{2}}, and \frac{\partial L}{\partial b}:

(10)   \begin{equation*} \frac{\partial L}{\partial w_{1}} = \frac{\partial L}{\partial z} \cdot \frac{\partial z}{\partial w_{1}} = \frac{\partial L}{\partial z} \cdot x_{1} \end{equation*}

(11)   \begin{equation*} \frac{\partial L}{\partial w_{2}} = \frac{\partial L}{\partial z} \cdot \frac{\partial z}{\partial w_{2}} = \frac{\partial L}{\partial z} \cdot x_{2} \end{equation*}

(12)   \begin{equation*} \frac{\partial L}{\partial b} = \frac{\partial L}{\partial z} \cdot \frac{\partial z}{\partial b} = \frac{\partial L}{\partial z} \end{equation*}

Here \frac{\partial L}{\partial z} is left unexpanded. In practice, once its value is obtained from equation (9), it can be substituted directly into equations (10), (11), and (12).
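Putting equations (9) through (12) together, the full forward and backward pass for the two-feature example can be sketched as follows (all numeric values are hypothetical):

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Hypothetical parameters, inputs, and label
w1, w2, b = 0.2, 0.4, -0.5
x1, x2, y = 1.0, 2.0, 0

# Forward pass through the computation graph
z = w1 * x1 + w2 * x2 + b
a = sigmoid(z)

# Backward pass
dz  = a - y       # dL/dz,  equation (9)
dw1 = dz * x1     # dL/dw1, equation (10)
dw2 = dz * x2     # dL/dw2, equation (11)
db  = dz          # dL/db,  equation (12)
```

In a gradient descent step these values would then be used to update the parameters, e.g. `w1 -= alpha * dw1` for some learning rate `alpha`.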