[Math] A confusion about the matrix chain rule

chain rule, matrices, matrix-calculus, vectors

I have the following equation:
$$\begin{align}Wx+b &= O\\
L &=f(O)\end{align}$$

where $W$ is an $N\times N$ matrix, $x$, $b$, $O$ are column vectors of size $N\times 1$, and $L$ is a scalar.

I want to calculate $\frac{\partial{L}}{\partial{W}}$, which will be a matrix of size $N\times N$. According to the chain rule, we also have $\frac{\partial{L}}{\partial{W}} = \frac{\partial{L}}{\partial{O}}\frac{\partial{O}}{\partial{W}}$ (I am not so sure about this equation). It is easy to calculate $\frac{\partial{L}}{\partial{O}}$, which is of size $N\times 1$. But $\frac{\partial{O}}{\partial{W}}$ is a 3-dimensional array of size $N\times N \times N$, if I understand it correctly.

My question is: is the chain rule I gave right? If it is, how do I multiply a 2-dimensional matrix by a 3-dimensional one? If I am wrong, please correct me.

Best Answer

I always feel insecure about these matrix derivatives, so I write everything in terms of components and expand the matrix products: $$ O_{i}= \sum_j W_{ij} \,x_{j} +b_{i} $$ Then, for any indices $k,j$: $$ \frac{\partial O_{i}}{\partial W_{kj} } = \begin{cases} x_j & i=k \\ 0 & i\ne k \end{cases} $$
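The component formula can be checked numerically. Below is a minimal pure-Python sketch (no NumPy, to stay self-contained) that builds the $N\times N\times N$ array $J_{ikj}=\partial O_i/\partial W_{kj}$ from the case formula and compares it with central finite differences of $O=Wx+b$. The helper names `affine`, `jacobian_O_wrt_W` and `numeric_jacobian`, the random test data and the step size `h` are illustrative assumptions, not part of the answer.

```python
import random

def affine(W, x, b):
    """O = W x + b, with W a list of rows and x, b plain lists."""
    n = len(x)
    return [sum(W[i][j] * x[j] for j in range(n)) + b[i] for i in range(n)]

def jacobian_O_wrt_W(x):
    """Analytic J[i][k][j] = dO_i/dW_kj = x_j if i == k else 0 (an N x N x N array)."""
    n = len(x)
    return [[[x[j] if i == k else 0.0 for j in range(n)]
             for k in range(n)]
            for i in range(n)]

def numeric_jacobian(W, x, b, h=1e-6):
    """Central finite differences of every O_i with respect to every W_kj."""
    n = len(x)
    J = [[[0.0] * n for _ in range(n)] for _ in range(n)]
    for k in range(n):
        for j in range(n):
            Wp = [row[:] for row in W]  # copy, then nudge only entry (k, j)
            Wm = [row[:] for row in W]
            Wp[k][j] += h
            Wm[k][j] -= h
            Op, Om = affine(Wp, x, b), affine(Wm, x, b)
            for i in range(n):
                J[i][k][j] = (Op[i] - Om[i]) / (2 * h)
    return J

random.seed(0)
N = 3
W = [[random.uniform(-1, 1) for _ in range(N)] for _ in range(N)]
x = [random.uniform(-1, 1) for _ in range(N)]
b = [random.uniform(-1, 1) for _ in range(N)]

analytic = jacobian_O_wrt_W(x)
numeric = numeric_jacobian(W, x, b)
max_err = max(abs(analytic[i][k][j] - numeric[i][k][j])
              for i in range(N) for k in range(N) for j in range(N))
print("max |analytic - numeric| =", max_err)  # should be at round-off level
```

Since $O$ is linear in $W$, the finite differences agree with the case formula up to floating-point round-off, and every entry with $i\ne k$ comes out exactly zero.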

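Now you can compute $\frac{\partial L}{\partial W}$ by applying the chain rule, summing over all components $O_i$ of the argument of $f$: $$ \frac{\partial L}{\partial W_{kj} } = \sum_i \frac{\partial L}{\partial O_{i} }\frac{\partial O_{i}}{\partial W_{kj} }=\frac{\partial L}{\partial O_{k} } x_j $$ since the terms with $i\ne k$ vanish. This sum over the shared index $i$ is how the $N\times 1$ vector $\frac{\partial L}{\partial O}$ "multiplies" the $N\times N\times N$ array $\frac{\partial O}{\partial W}$: it is a contraction over one index, not an ordinary matrix product.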
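To make the contraction concrete, here is a small pure-Python sketch that takes a vector $g$ standing in for $\partial L/\partial O$, builds $J_{ikj}=\partial O_i/\partial W_{kj}$, and sums over the first index $i$. The values of `g` and `x` are arbitrary, and the helper names are hypothetical.

```python
def jacobian_O_wrt_W(x):
    """J[i][k][j] = dO_i/dW_kj for O = W x + b: x_j if i == k, else 0."""
    n = len(x)
    return [[[x[j] if i == k else 0.0 for j in range(n)]
             for k in range(n)]
            for i in range(n)]

def contract_first_index(g, J):
    """A[k][j] = sum_i g[i] * J[i][k][j]: the chain-rule sum over the shared index i."""
    n = len(g)
    return [[sum(g[i] * J[i][k][j] for i in range(n)) for j in range(n)]
            for k in range(n)]

g = [0.5, -2.0, 3.0]  # hypothetical values of dL/dO_i
x = [1.0, 4.0, -1.0]

A = contract_first_index(g, jacobian_O_wrt_W(x))
outer = [[g[k] * x[j] for j in range(len(x))] for k in range(len(g))]

print(A)           # [[0.5, 2.0, -0.5], [-2.0, -8.0, 2.0], [3.0, 12.0, -3.0]]
print(A == outer)  # True
```

Only the $i=k$ slice of the array survives the sum, which is why the result reduces to $g_k x_j$.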

And that's it. But you may want to rewrite this expression as a matrix product. Consider the vector $g$ with $g_k=\frac{\partial L}{\partial O_{k}}$ and the matrix of derivatives $A$ with $A_{kj}=\frac{\partial L}{\partial W_{kj}}$. The component formula gives $A_{kj}=g_k x_j$, i.e. $A=g\cdot x^\top$ (the outer product of two $N\times 1$ vectors, hence an $N\times N$ matrix), so with some abuse of notation you can write: $$ \frac{\partial L}{\partial W } =g\cdot x^\top $$
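The matrix form can be sanity-checked end to end with finite differences. This pure-Python sketch assumes the hypothetical choice $f(O)=\tfrac12\sum_i O_i^2$, for which $g=\partial L/\partial O=O$, and compares $g\,x^\top$ with a numerical gradient of $L$ with respect to each $W_{kj}$. The helper names, random data and step size are illustrative assumptions.

```python
import random

def forward(W, x, b):
    """Return O = W x + b and L = f(O), with the hypothetical choice f(O) = 0.5 * sum(O_i^2)."""
    n = len(x)
    O = [sum(W[i][j] * x[j] for j in range(n)) + b[i] for i in range(n)]
    L = 0.5 * sum(o * o for o in O)
    return O, L

def grad_W_outer(W, x, b):
    """dL/dW = g x^T, where g_k = dL/dO_k (equal to O_k for this particular f)."""
    O, _ = forward(W, x, b)
    g = O
    return [[g[k] * x[j] for j in range(len(x))] for k in range(len(g))]

def grad_W_numeric(W, x, b, h=1e-6):
    """Central finite differences of L with respect to each entry W_kj."""
    n = len(x)
    G = [[0.0] * n for _ in range(n)]
    for k in range(n):
        for j in range(n):
            Wp = [row[:] for row in W]
            Wm = [row[:] for row in W]
            Wp[k][j] += h
            Wm[k][j] -= h
            G[k][j] = (forward(Wp, x, b)[1] - forward(Wm, x, b)[1]) / (2 * h)
    return G

random.seed(1)
N = 4
W = [[random.uniform(-1, 1) for _ in range(N)] for _ in range(N)]
x = [random.uniform(-1, 1) for _ in range(N)]
b = [random.uniform(-1, 1) for _ in range(N)]

G_outer = grad_W_outer(W, x, b)
G_num = grad_W_numeric(W, x, b)
max_err = max(abs(G_outer[k][j] - G_num[k][j]) for k in range(N) for j in range(N))
print("max |g x^T - finite difference| =", max_err)
```

Swapping in another differentiable $f$ only changes how $g$ is computed; the outer-product shape of the gradient comes from $O=Wx+b$ alone.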