Product rule for trace of matrix functions

calculus, chain rule, derivatives, matrices, trace

I am trying to find the gradient of $f(Z_2) = \|A - Zg(Z_1g(Z_2))\|_F^2$ with respect to $Z_2$, where the function $g$ is applied elementwise, so that the $(i,j)$ element of the matrix $g(X)$ is $g(X_{ij})$, and $A, Z, Z_1, Z_2 \in \mathbb{R}^{n \times n}$. Examples of $g$ are the exponential function, the tanh function, etc. I want to run gradient descent to find the minimum of $f(Z_2)$, for which I need the gradient. I know $$f(Z_2) = \operatorname{trace}((A - Zg(Z_1g(Z_2)))^T(A - Zg(Z_1g(Z_2))))$$ $$\Rightarrow \frac{\partial f(Z_2)}{\partial Z_2} = \frac{\partial \operatorname{trace}(-A^TZg(Z_1g(Z_2)) - (Zg(Z_1g(Z_2)))^TA + (Zg(Z_1g(Z_2)))^T(Zg(Z_1g(Z_2)))) }{\partial Z_2}$$
I am not sure how to use the chain rule to find the above gradient.

Edit 1: I tried to solve a smaller problem by taking $f(Z_2) = \|A - g(Z_2)\|_F^2$. In this case
\begin{align*}
\frac{\partial f(Z_2)}{\partial Z_2} &= \frac{\partial }{\partial Z_2}\operatorname{trace}(A^TA - g(Z_2)^TA - A^Tg(Z_2) + g(Z_2)^Tg(Z_2)) \\
& = -A^Tg'(Z_2) -g'(Z_2)^TA + g'(Z_2)^Tg(Z_2) + g(Z_2)^Tg'(Z_2)
\end{align*}

For the above derivation, I used "Derivative of trace functions using chain rule". Does the above solution look correct?

Best Answer

Note that for an elementwise function, the differential involves the elementwise (Hadamard) product. Writing $\Gamma_X = g'(X)$ for the elementwise derivative,
$$\eqalign{
dg(X) &= g'(X) \odot dX \,=\, \Gamma_X\odot dX \cr
}$$
Define a cascade of variables.
$$\eqalign{
X &= g(Z_2) \cr
Y &= Z_1X \cr
W &= g(Y) \cr
V &= ZW \cr
U &= V-A \cr
f &= U:U \cr
}$$
Take the differential of the last variable, and reverse the cascade.
$$\eqalign{
df &= 2\,U:dU \cr
&= 2\,U:dV \cr
&= 2\,U:Z\,dW \cr
&= 2\,Z^TU:dW \cr
&= 2\,Z^TU:\Gamma_Y\odot dY \cr
&= 2\,\Gamma_Y\odot(Z^TU):dY \cr
&= 2\,\Gamma_Y\odot(Z^TU):Z_1\,dX \cr
&= 2\,Z_1^T(\Gamma_Y\odot(Z^TU)):dX \cr
&= 2\,Z_1^T(\Gamma_Y\odot(Z^TU)):\Gamma_{Z_2}\odot dZ_2 \cr
&= 2\,\Gamma_{Z_2}\odot(Z_1^T(\Gamma_Y\odot(Z^TU))):dZ_2 \cr
\frac{\partial f}{\partial Z_2} &= 2\,\Gamma_{Z_2}\odot(Z_1^T(\Gamma_Y\odot(Z^TU))) \cr
&= 2\,g'(Z_2)\odot\Big(Z_1^T\big(g'(Y)\odot(Z^TU)\big)\Big) \cr
}$$
NB: A colon denotes the trace/Frobenius product, i.e. $A:B = {\rm Tr}(A^TB)$. The two rearrangements used in the derivation are $A:BC = B^TA:C$ and $A:(B\odot C) = (A\odot B):C$.
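Since the question mentions gradient descent, it may help to sanity-check the closed-form gradient numerically before using it. The following is a minimal NumPy sketch, not part of the original answer. It assumes $g = \tanh$ (one of the examples named in the question), random $4 \times 4$ test matrices, and a central finite difference as the reference. The function names `f`, `grad_f`, and `numerical_grad` are made up for illustration.

```python
import numpy as np

def g(X):
    # Elementwise nonlinearity; tanh is an assumed example choice.
    return np.tanh(X)

def g_prime(X):
    # Elementwise derivative of tanh.
    return 1.0 - np.tanh(X) ** 2

def f(A, Z, Z1, Z2):
    # f(Z2) = ||A - Z g(Z1 g(Z2))||_F^2, written as U:U with U = Z g(Z1 g(Z2)) - A.
    U = Z @ g(Z1 @ g(Z2)) - A
    return np.sum(U * U)

def grad_f(A, Z, Z1, Z2):
    # Closed form: 2 g'(Z2) ⊙ (Z1^T (g'(Y) ⊙ (Z^T U))), with Y = Z1 g(Z2).
    Y = Z1 @ g(Z2)
    U = Z @ g(Y) - A
    return 2.0 * g_prime(Z2) * (Z1.T @ (g_prime(Y) * (Z.T @ U)))

def numerical_grad(A, Z, Z1, Z2, eps=1e-6):
    # Central finite difference, one entry of Z2 at a time.
    G = np.zeros_like(Z2)
    for i in range(Z2.shape[0]):
        for j in range(Z2.shape[1]):
            E = np.zeros_like(Z2)
            E[i, j] = eps
            G[i, j] = (f(A, Z, Z1, Z2 + E) - f(A, Z, Z1, Z2 - E)) / (2 * eps)
    return G

rng = np.random.default_rng(0)
n = 4
A, Z, Z1, Z2 = (rng.standard_normal((n, n)) for _ in range(4))

# Largest entrywise gap between the closed form and the finite difference.
max_abs_err = np.max(np.abs(grad_f(A, Z, Z1, Z2) - numerical_grad(A, Z, Z1, Z2)))
print(max_abs_err)
```

If the printed gap is small relative to the size of the gradient entries, the formula and the implementation agree. In `grad_f`, `*` is the Hadamard product and `@` is the matrix product, matching $\odot$ and ordinary multiplication in the derivation. Applying the same differential argument to the smaller problem $f(Z_2) = \|A - g(Z_2)\|_F^2$ gives $2\,g'(Z_2)\odot(g(Z_2) - A)$, which can be checked the same way.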