[Math] Using Chain Rule in Matrix Differentiation

derivatives, matrices, matrix-calculus, neural-networks

I have the following parameters and their respective dimensions:

$X:2\times 1$, $W_1:7\times2$, $W_2:1\times7$, $B_1: 7\times 1$ and $B_2:1\times1$, with the following formulation:

$Y=W_2H+B_2$, where $H=\verb+ReLU+(W_1X+B_1)$ and $\verb+ReLU+(x)=\max(0,x)$ is the rectified linear unit applied element-wise. I want to compute $$\frac{\partial Y}{\partial W_1},$$
by using the chain rule. Hence, I compute

$$\frac{\partial Y}{\partial W_1}=\frac{\partial Y}{\partial H}\cdot\frac{\partial H}{\partial W_1}$$

which is equal to $$W_2\cdot\frac{\partial H}{\partial W_1}.$$

My problem is computing $\frac{\partial H}{\partial W_1}$. Applying the chain rule, I pull an $X^T$ out of this term, but then the dimensions don't match for the multiplication. How do we take the derivative of $H$ with respect to $W_1$, which is a $7\times 2$ matrix?

I was told that the final result has dimension $7\times2$, but no matter how I arrange things, I can't come up with the correct result.
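For concreteness, here is a minimal NumPy sketch of the forward pass with these shapes (variable names and random values are chosen only for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
X  = rng.standard_normal((2, 1))   # input,                2x1
W1 = rng.standard_normal((7, 2))   # first-layer weights,  7x2
B1 = rng.standard_normal((7, 1))   # first-layer bias,     7x1
W2 = rng.standard_normal((1, 7))   # second-layer weights, 1x7
B2 = rng.standard_normal((1, 1))   # second-layer bias,    1x1

H = np.maximum(0, W1 @ X + B1)     # ReLU applied elementwise, 7x1
Y = W2 @ H + B2                    # output, 1x1
```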

Best Answer

Rather than the chain rule, let's tackle the problem using differentials.

Let's use the convention that an uppercase letter is a matrix, a lowercase letter is a column vector, and a Greek letter is a scalar; in your notation $x=X$, $b_1=B_1$, $w_2=W_2^T$, and $\beta_2=B_2$. Now let's define some variables and their differentials $$\eqalign{ z &= W_1x+b_1 &\implies dz=dW_1\,x \cr h &= {\rm relu}(z) &\implies dh={\rm step}(z)\odot dz = s\odot dz \cr }$$ where ${\rm step}(z)$ is the Heaviside step function. Both the relu and step functions are applied elementwise to their arguments, and the first differential holds because $x$ and $b_1$ do not depend on $W_1$.
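Componentwise, the second differential is just the derivative of the relu (ignoring the kink at $z_i=0$): $$dh_i = \begin{cases} dz_i, & z_i>0 \\ 0, & z_i<0 \end{cases} \;=\; {\rm step}(z_i)\,dz_i,$$ which is exactly what the Hadamard product $s\odot dz$ encodes.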

Now take the differential and gradient of the scalar output $\phi=Y$ (a scalar, since $B_2$ is $1\times1$). $$\eqalign{ \phi &= w_2^Th + \beta_2 \cr &= w_2:h + \beta_2 \cr \cr d\phi &= w_2:dh \cr &= w_2:(s\odot dz) \cr &= (s\odot w_2):dz \cr &= (s\odot w_2):dW_1\,x \cr &= (s\odot w_2)x^T:dW_1 \cr \cr \frac{\partial\phi}{\partial W_1} &= (w_2\odot s)\,x^T \cr &= \Big(w_2\odot{\rm step}(W_1x+b_1)\Big)x^T \cr }$$ In the above, $A:B = {\rm tr}(A^TB)$ denotes the trace/Frobenius product and $A\odot B$ the elementwise/Hadamard product; the key identities used are $A:(B\odot C)=(A\odot B):C$ and $A:BC = AC^T:B$. Note that the result is a $7\times1$ vector times a $1\times2$ row vector, i.e. a $7\times2$ matrix, as expected.
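If you want to sanity-check this numerically, here is a minimal NumPy sketch (variable names and random values are mine) that compares the formula $\big(w_2\odot{\rm step}(W_1x+b_1)\big)x^T$ against central finite differences:

```python
import numpy as np

rng = np.random.default_rng(1)
x     = rng.standard_normal((2, 1))   # input,   2x1
W1    = rng.standard_normal((7, 2))   # weights, 7x2
b1    = rng.standard_normal((7, 1))   # bias,    7x1
w2    = rng.standard_normal((7, 1))   # column vector, i.e. W_2 transposed
beta2 = rng.standard_normal()         # scalar bias

def phi(W1):
    """Scalar output phi = w2^T relu(W1 x + b1) + beta2."""
    h = np.maximum(0, W1 @ x + b1)
    return (w2.T @ h + beta2).item()

# analytic gradient: (w2 ⊙ step(W1 x + b1)) x^T, a 7x2 matrix
s = (W1 @ x + b1 > 0).astype(float)   # step(z), elementwise
grad = (w2 * s) @ x.T

# central finite differences, one entry of W1 at a time
eps = 1e-6
num = np.zeros_like(W1)
for i in range(7):
    for j in range(2):
        E = np.zeros_like(W1)
        E[i, j] = eps
        num[i, j] = (phi(W1 + E) - phi(W1 - E)) / (2 * eps)

# agreement can only fail if some component of W1 x + b1 lands within eps of 0,
# where the relu is not differentiable
print(np.allclose(grad, num, atol=1e-5))   # True
```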