[Math] Does the order matter for the chain rule?

calculus, derivatives, matrices

Given $$\frac{\partial J}{\partial z_2}=\delta_1$$
$$z_2 = hW_2+b_2$$
Derive the gradients of $J$ with respect to $h$ and $W_2$, where $J \in \mathbb{R}$, $z_2 \in \mathbb{R}^{D_x \times D_y}$, $\delta_1 \in \mathbb{R}^{D_x \times D_y}$, $W_2 \in \mathbb{R}^{H \times D_y}$, and $h \in \mathbb{R}^{D_x \times H}$.

Here's the correct solution:
\begin{align*}
&\frac{\partial J}{\partial h}=\frac{\partial J}{\partial z_2} \frac{\partial z_2}{\partial h}=\delta_1 W_2^T\\
&\frac{\partial J}{\partial W_2}=\frac{\partial z_2}{\partial W_2} \frac{\partial J}{\partial z_2}=h^T \delta_1
\end{align*}
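A quick way to confirm these two formulas, and to catch a product written in the wrong order, is a finite-difference check. The sketch below is only an illustration and makes several assumptions that are not in the question: it uses NumPy, arbitrary small sizes, a bias $b_2$ with the same shape as $z_2$, and a hypothetical loss $J=\sum\sin(z_2)$ chosen solely because its $\delta_1=\cos(z_2)$ is known in closed form.

```python
import numpy as np

rng = np.random.default_rng(0)
Dx, Dy, H = 3, 4, 5  # arbitrary small sizes chosen for the check

h = rng.standard_normal((Dx, H))
W2 = rng.standard_normal((H, Dy))
b2 = rng.standard_normal((Dx, Dy))  # assumption: b2 has the same shape as z2


def J(h, W2):
    # Hypothetical scalar loss; any smooth J would do.
    z2 = h @ W2 + b2
    return np.sum(np.sin(z2))


# For this particular J, delta_1 = dJ/dz2 = cos(z2) elementwise.
delta1 = np.cos(h @ W2 + b2)

grad_h = delta1 @ W2.T   # claimed dJ/dh, shape (Dx, H)
grad_W2 = h.T @ delta1   # claimed dJ/dW2, shape (H, Dy)


def numerical_grad(f, X, eps=1e-6):
    # Central finite differences, perturbing one entry of X at a time.
    G = np.zeros_like(X)
    for idx in np.ndindex(X.shape):
        E = np.zeros_like(X)
        E[idx] = eps
        G[idx] = (f(X + E) - f(X - E)) / (2 * eps)
    return G


num_h = numerical_grad(lambda X: J(X, W2), h)
num_W2 = numerical_grad(lambda X: J(h, X), W2)

print(np.allclose(grad_h, num_h, atol=1e-6), np.allclose(grad_W2, num_W2, atol=1e-6))
```

Note that the check works entry by entry, so it confirms the formulas without relying on any shape-matching argument.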

The results are obtained by applying the chain rule, but chaining in different orders. The change of order reflects a compromise to satisfy the dimension requirements of $\frac{\partial J}{\partial W_2}$. It's very annoying to have to check the dimensions every time. Is there a general rule for knowing which order to use when applying the chain rule, without examining the dimensions?

Best Answer

Using differentials is less error-prone (for me) than the chain rule because, algebraically, differentials behave like ordinary vectors and matrices.

Applying the differential technique to your question:
$$\begin{aligned} z_2 &= hW_2 + b_2 \\ dz_2 &= dh\,W_2 + h\,dW_2 \\[4pt] dJ &= \delta_1:dz_2 \\ &= \delta_1:(dh\,W_2 + h\,dW_2) \\ &= \delta_1:dh\,W_2 + \delta_1:h\,dW_2 \\ &= \delta_1\,W_2^T:dh + h^T\delta_1:dW_2 \end{aligned}$$
Here $dJ=\delta_1:dz_2$ is just the given $\partial J/\partial z_2=\delta_1$ written in differential form, and the last step moves $W_2$ and $h$ across the colon using the rearrangement rules listed below.

Setting $dW_2=0$ yields the gradient with respect to $h$:
$$\frac{\partial J}{\partial h} = \delta_1\,W_2^T$$
And setting $dh=0$ yields the gradient with respect to $W_2$:

$$\frac{\partial J}{\partial W_2} = h^T\delta_1$$

In the above, a colon denotes the double-dot (aka Frobenius) product, which can be defined in terms of the trace as
$$A:B=\operatorname{tr}(A^TB)$$
The properties of the trace give rise to some useful rules for rearranging the arguments:
$$\begin{aligned} A:BC &= BC:A \\ &= A^T:(BC)^T \\ &= B^TA:C \\ &= AC^T:B \\ &= \text{etc.} \end{aligned}$$
A good way to spot errors with this technique is to remember that both arguments of a Frobenius product must have the same shape. In this respect it resembles the elementwise (aka Hadamard) product; in fact, the Frobenius product is just the sum of all the elements of the Hadamard product.
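The rearrangement rules, and the two moves used in the derivation of $dJ$, can be spot-checked numerically. This is a sketch assuming NumPy and random matrices of arbitrary (but compatible) sizes; the helper `frob` is a hypothetical name for $A:B=\operatorname{tr}(A^TB)$, and it deliberately rejects mismatched shapes to mirror the "same shape" error check described above.

```python
import numpy as np

rng = np.random.default_rng(2)


def frob(A, B):
    """Frobenius product A:B = tr(A^T B); both arguments must share a shape."""
    if A.shape != B.shape:
        raise ValueError(f"shape mismatch: {A.shape} vs {B.shape}")
    return np.trace(A.T @ B)


# Generic rearrangement rules for A:BC
A = rng.standard_normal((3, 5))
B = rng.standard_normal((3, 4))
C = rng.standard_normal((4, 5))
lhs = frob(A, B @ C)
rules = {
    "BC:A": frob(B @ C, A),
    "A^T:(BC)^T": frob(A.T, (B @ C).T),
    "B^T A:C": frob(B.T @ A, C),
    "A C^T:B": frob(A @ C.T, B),
    "sum of Hadamard": np.sum(A * (B @ C)),
}

# The two moves used in the derivation of dJ (sizes are arbitrary)
Dx, Dy, H = 3, 4, 5
delta1 = rng.standard_normal((Dx, Dy))
h, dh = rng.standard_normal((Dx, H)), rng.standard_normal((Dx, H))
W2, dW2 = rng.standard_normal((H, Dy)), rng.standard_normal((H, Dy))
move_h = (frob(delta1, dh @ W2), frob(delta1 @ W2.T, dh))
move_W2 = (frob(delta1, h @ dW2), frob(h.T @ delta1, dW2))

for name, value in rules.items():
    print(name, np.isclose(lhs, value))
print("move h:", np.isclose(*move_h), "move W2:", np.isclose(*move_W2))
```

Calling `frob(A, B)` here raises a `ValueError`, which is exactly the kind of shape mistake the same-shape rule is meant to catch.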

Another useful rule is the product rule for differentials,
$$d(A\star B) = dA\star B + A\star dB$$
where $\star$ can be the Frobenius, Hadamard, Kronecker, tensor (aka dyadic), or ordinary matrix product. The Frobenius and Hadamard products are also commutative, which lets you reorder and collect terms much as you would with scalar quantities.
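Because every product listed is bilinear, the exact change in $A\star B$ differs from $dA\star B + A\star dB$ only by the second-order term $dA\star dB$. A minimal sketch, assuming NumPy, random $3\times 3$ matrices, and perturbations of size about $10^{-6}$, checks this for each product and also confirms which products commute; `np.tensordot(X, Y, axes=0)` is used here as the tensor (outer) product.

```python
import numpy as np

rng = np.random.default_rng(3)


def frob(X, Y):
    # Frobenius product X:Y as the sum of the Hadamard product
    return np.sum(X * Y)


# Each entry is a bilinear product; square inputs keep every product well-defined.
products = {
    "matrix": lambda X, Y: X @ Y,
    "hadamard": lambda X, Y: X * Y,
    "kronecker": np.kron,
    "tensor": lambda X, Y: np.tensordot(X, Y, axes=0),
    "frobenius": frob,
}

A, B = rng.standard_normal((3, 3)), rng.standard_normal((3, 3))
dA, dB = 1e-6 * rng.standard_normal((3, 3)), 1e-6 * rng.standard_normal((3, 3))

gaps = {}
for name, prod in products.items():
    exact_change = prod(A + dA, B + dB) - prod(A, B)
    product_rule = prod(dA, B) + prod(A, dB)
    # For a bilinear product the two differ by prod(dA, dB), a second-order term
    gaps[name] = float(np.max(np.abs(exact_change - product_rule)))

print(gaps)
print("Frobenius commutes:", np.isclose(frob(A, B), frob(B, A)))
print("Hadamard commutes:", np.allclose(A * B, B * A))
print("matrix product commutes:", np.allclose(A @ B, B @ A))
```

The gaps come out on the order of $10^{-12}$, i.e. second order in the perturbation size, while the ordinary matrix product is the one that does not commute.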

Finally, to say that $G$ is the gradient of $f(X)$ means that the differential is
$$df=G:dX$$
and vice versa.
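As an illustration of reading a gradient off $df=G:dX$, take the hypothetical function $f(X)=\|AX-B\|_F^2=R:R$ with $R=AX-B$ (not from the question). Then $df=2R:dR=2R:A\,dX=2A^TR:dX$, so $G=2A^T(AX-B)$. The sketch below, assuming NumPy and random matrices, compares $G:dX$ with a central-difference directional derivative of $f$ along a random $dX$.

```python
import numpy as np

rng = np.random.default_rng(4)
A = rng.standard_normal((4, 3))
B = rng.standard_normal((4, 2))
X = rng.standard_normal((3, 2))


def f(X):
    # Hypothetical example: f(X) = ||AX - B||_F^2 = R:R with R = AX - B
    R = A @ X - B
    return np.sum(R * R)


def grad_f(X):
    # df = 2R:dR = 2R:A dX = 2A^T R:dX, so G = 2A^T R
    return 2 * A.T @ (A @ X - B)


# Compare the directional derivative of f along a random dX with G:dX
dX = rng.standard_normal(X.shape)
eps = 1e-6
directional = (f(X + eps * dX) - f(X - eps * dX)) / (2 * eps)
predicted = np.sum(grad_f(X) * dX)
print(directional, predicted)
```

Since $f$ is quadratic, the central difference matches $G:dX$ up to rounding error, and $G$ automatically has the same shape as $X$.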