How can the Jacobian of the softmax function be used for calculating the gradient of a linear layer

linear algebra, machine learning

I am creating a visual transformer from scratch as an exercise to get a better understanding of machine learning. Within the network, two matrices are multiplied, the softmax function is applied to the result, and that result is then multiplied by a constant I will refer to as $C$. This is not the last layer, so I am assuming cross entropy would not be used here, and more layers are applied afterward. Still, I am only planning on focusing on the part I covered.

So take the matrices $Q$ and $K$, where $Q$ is a $3\times2$ matrix and $K$ is a $2\times3$ matrix, and let $A = QK$, the $3\times3$ matrix that results from multiplying them.

Where,
$$Q = \begin{pmatrix}q_{1,1} & q_{1,2}\\ q_{2,1} & q_{2,2}\\q_{3,1} & q_{3,2}\end{pmatrix}$$

and

$$K = \begin{pmatrix}k_{1,1} & k_{1,2} & k_{1,3}\\ k_{2,1} & k_{2,2} & k_{2,3}\end{pmatrix}$$

making $$A = \begin{pmatrix}q_{1,1}k_{1,1} + q_{1,2}k_{2,1} & q_{1,1}k_{1,2} + q_{1,2}k_{2,2} & q_{1,1}k_{1,3} + q_{1,2}k_{2,3}\\ q_{2,1}k_{1,1} + q_{2,2}k_{2,1} & q_{2,1}k_{1,2} + q_{2,2}k_{2,2} & q_{2,1}k_{1,3} + q_{2,2}k_{2,3}\\ q_{3,1}k_{1,1} + q_{3,2}k_{2,1} & q_{3,1}k_{1,2} + q_{3,2}k_{2,2} & q_{3,1}k_{1,3} + q_{3,2}k_{2,3}\end{pmatrix} = \begin{pmatrix} a_{1,1} & a_{1,2} & a_{1,3} \\ a_{2,1} & a_{2,2} & a_{2,3} \\ a_{3,1} & a_{3,2} & a_{3,3}\end{pmatrix}$$

Then the softmax function, defined as
$$S(x_{i}) = \frac{e^{x_{i}}}{\sum_{j=1}^{n} e^{x_{j}}},$$
is applied to $A$ (here, over all nine entries at once).

Afterwards, $B = S(A) \cdot C$.

So then on the backward pass with respect to the loss $L$: let $Z = S(A)$, so that $B = Z \cdot C$. Flattening the entries of $Z$ into a vector, the derivative of the softmax function is given by

$z_{i}(1-z_{i})$ when $i = j$ and
$-z_{i}z_{j}$ when $i \neq j$,

making the softmax Jacobian
$$\frac{\partial Z}{\partial A} = \begin{pmatrix}
z_{1,1}(1-z_{1,1}) & -z_{1,1}z_{1,2} & \cdots & -z_{1,1}z_{3,3} \\ -z_{1,2}z_{1,1} & z_{1,2}(1-z_{1,2}) & \cdots & -z_{1,2}z_{3,3} \\ \vdots & \vdots & \ddots & \vdots \\ -z_{3,3}z_{1,1} & -z_{3,3}z_{1,2} & \cdots & z_{3,3}(1-z_{3,3})\end{pmatrix}$$

Which is a $9×9$ matrix
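As a concrete sanity check, a minimal NumPy sketch that builds this $9\times9$ Jacobian (softmax taken over all nine entries of $A$, as above) and verifies one column with a finite difference:

```python
import numpy as np

def softmax(x):
    # Softmax over all entries of a flattened array,
    # shifted by max(x) for numerical stability.
    e = np.exp(x - np.max(x))
    return e / e.sum()

A = np.random.randn(3, 3)
z = softmax(A.flatten())           # 9 entries z_1, ..., z_9

# Jacobian: z_i(1 - z_i) on the diagonal, -z_i z_j off it.
J = np.diag(z) - np.outer(z, z)    # shape (9, 9)

# Finite-difference check of the first column (derivative w.r.t. the first entry).
eps = 1e-6
e0 = np.zeros(9); e0[0] = eps
a = A.flatten()
col = (softmax(a + e0) - softmax(a - e0)) / (2 * eps)
print(np.allclose(J[:, 0], col, atol=1e-6))   # True
```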

What I don't understand is how I can calculate $\frac{\partial A}{\partial Q}$ and $\frac{\partial A}{\partial K}$ to find $\frac{\partial L}{\partial Q}$ and $\frac{\partial L}{\partial K}$, in order to allow back propagation to continue. From what I understand, I have to compute the Jacobian of $A$ with respect to $Q$ to find $\frac{\partial A}{\partial Q}$, but I am not entirely sure how to do that and then get back to the same dimensions as $Q$.

Best Answer

Direct answer

You can just calculate the derivatives directly from your formula for $A$. I can write your formula as $$a_{ij} = \sum_{\alpha} q_{i\alpha} k_{\alpha j}$$ so then we have $$\frac{da_{ij}}{dq_{rs}} = \delta_{ir} k_{sj}$$ where $\delta$ is the Kronecker delta.
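If you want to sanity-check this formula, here is a minimal NumPy sketch that builds the full 4D tensor with einsum and compares one slice against a finite difference:

```python
import numpy as np

Q = np.random.randn(3, 2)
K = np.random.randn(2, 3)

# dA[i,j]/dQ[r,s] = delta(i,r) * K[s,j], a 3 x 3 x 3 x 2 tensor.
dA_dQ = np.einsum('ir,sj->ijrs', np.eye(3), K)

# Finite-difference check of the (r, s) = (0, 1) slice.
eps = 1e-6
dQ = np.zeros_like(Q); dQ[0, 1] = eps
num = ((Q + dQ) @ K - (Q - dQ) @ K) / (2 * eps)
print(np.allclose(dA_dQ[:, :, 0, 1], num))    # True
```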

In general, if you have a function from an $n$-dimensional block of numbers to an $m$-dimensional block of numbers, then the derivative will be an $(n+m)$-dimensional tensor. In case you aren't used to the word "tensor", just think a vector is a 1D block of numbers, a matrix is a 2D block of numbers, and then this derivative will be an $(n+m)$-D block of numbers. That should make sense because, as we see in the formula above, to fully understand $\frac{dA}{dQ}$ we need to know the derivative of every entry of $A$ with respect to every entry of $Q$.

The next logical question would be: how do we use that for back propagation? In this case your real goal is to compute $\frac{dL}{dQ}$, and you know $L$ as a function $L(A(Q,K))$. From the softmax Jacobian you computed (contracted with the gradient arriving from the layers above), you can obtain $\frac{dL}{dA}$, which is a 2D tensor; that matches my formula above, because the function that computes $L$ from $A$ has a 2D input and a 0D output.

The formula you want will be: $$\begin{align} \frac{dL}{dQ_{rs}} &= \sum_{ij} \frac{dL}{dA_{ij}} \frac{dA_{ij}}{dQ_{rs}} \\ &= \sum_{ij} \frac{dL}{dA_{ij}} \delta_{ir} k_{sj} \\ &= \sum_{j} \frac{dL}{dA_{rj}} k_{sj} \end{align}$$ where $\frac{dL}{dA_{ij}}$ is the upstream derivative discussed above. In matrix form, that last line is simply $\frac{dL}{dQ} = \frac{dL}{dA}\, K^\top$.

Of course, you can repeat the same steps for $\frac{dL}{dK}$: since $\frac{dA_{ij}}{dK_{rs}} = q_{ir}\, \delta_{js}$, you get $\frac{dL}{dK_{rs}} = \sum_i \frac{dL}{dA_{is}}\, q_{ir}$, which in matrix form is $\frac{dL}{dK} = Q^\top\, \frac{dL}{dA}$.
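In code, both gradients are one matrix product each. A minimal NumPy sketch, using an arbitrary stand-in $G$ for $\frac{dL}{dA}$ and the surrogate loss $L = \sum_{ij} G_{ij} A_{ij}$ (chosen so that $\frac{dL}{dA} = G$ exactly) for the finite-difference check:

```python
import numpy as np

Q = np.random.randn(3, 2)
K = np.random.randn(2, 3)
G = np.random.randn(3, 3)    # stand-in for dL/dA flowing in from above

# dL/dQ[r,s] = sum_j G[r,j] K[s,j]   ->  G @ K.T   (same shape as Q)
# dL/dK[r,s] = sum_i G[i,s] Q[i,r]   ->  Q.T @ G   (same shape as K)
dL_dQ = G @ K.T
dL_dK = Q.T @ G

# Finite-difference check using the surrogate loss L = sum(G * (Q @ K)),
# whose dL/dA is exactly G.
eps = 1e-6
dQ = np.zeros_like(Q); dQ[1, 0] = eps
num = (np.sum(G * ((Q + dQ) @ K)) - np.sum(G * ((Q - dQ) @ K))) / (2 * eps)
print(np.isclose(dL_dQ[1, 0], num))   # True
```

Note that both gradients come out with the same shapes as $Q$ and $K$ themselves, which is exactly the "revert back to the same dimensions" property you were after.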

Example that might help the intuition?

Say we have a function $f:\mathbb{R} \to \mathbb{R}$ given by $f(x) = g(h(x))$ where $h(x) = (x^2, x^3)$ and $g(a, b) = a \ln(b)$. Then what's $\frac{df}{dx}$?

You probably learned to handle this kind of situation in a multivariable calculus class at some point. We can compute $$\begin{align} \frac{df}{dx} &= \frac{df}{dh_1} \frac{dh_1}{dx} + \frac{df}{dh_2} \frac{dh_2}{dx} \\ &= \ln(h_2(x)) (2x) + \frac{h_1(x)}{h_2(x)} (3x^2) \\ &= 2x \ln(x^3) + 3x^2 \frac{x^2}{x^3} \\ &= 6x \ln(x) + 3x \end{align}$$ which you could confirm if you want by manually differentiating $f(x) = x^2 \ln(x^3)$.

Can you see how this computation is just a really simple version of the same method I was using for your real problem above? In this case the mappings just go $\mathbb{R} \to \mathbb{R}^2 \to \mathbb{R}$ but you still apply the same idea where $\frac{df}{dx} = \sum_{j} \frac{df}{dh_j} \frac{dh_j}{dx}$.
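A quick numerical confirmation of that example:

```python
import numpy as np

f = lambda t: t**2 * np.log(t**3)      # f(x) = g(h(x)) written out directly
x, eps = 1.7, 1e-6
numeric = (f(x + eps) - f(x - eps)) / (2 * eps)
analytic = 6 * x * np.log(x) + 3 * x   # chain-rule result from above
print(np.isclose(numeric, analytic))   # True
```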

Further reading

You can find a bunch of good reading material on this topic by Googling chain rule for matrices using tensors; several of the hits look reasonable to me at first glance. But you don't need to trust any particular source; really I think you can learn more yourself now that I've told you the right keywords to look up. Good luck :)


EDIT: Computing $\mathbf{\frac{dS(A)}{dQ}}$

Per discussion in the comments, OP would like me to also write out how to compute $\frac{dS(A)}{dQ}$.

We want the derivative of the map $Q \mapsto S(A)$. We'll compute it via the chain rule, using the composition $Q \mapsto A \mapsto S(A)$. We already computed the derivative of the first step, $\frac{dA}{dQ}$. For the second step, $\frac{dS(A)}{dA}$, we are differentiating a $3 \times 3$ matrix with respect to a $3 \times 3$ matrix, so we expect a 4-dimensional output tensor, which will be $3 \times 3 \times 3 \times 3$.

Writing $\sigma_{ij} = S(A)_{ij}$ for the softmax outputs (these, not the raw entries of $A$, are what appear in the derivative), the formula you already computed gives $$\begin{align} \frac{dS(A)_{ij}}{dA_{rs}} &= \delta_{ir} \delta_{js}\, \sigma_{ij} - \sigma_{ij} \sigma_{rs}. \end{align}$$ Now we use the chain rule to compute the final derivative we're looking for: $$\begin{align} \frac{dS(A)_{ij}}{dQ_{cd}} &= \sum_{r} \sum_{s} \frac{dS(A)_{ij}}{dA_{rs}} \frac{dA_{rs}}{dQ_{cd}} \\ &= \sum_{r,s} \left(\delta_{ir} \delta_{js}\, \sigma_{ij} - \sigma_{ij} \sigma_{rs} \right) \left( \delta_{rc} k_{ds} \right) \\ &= \delta_{ic}\, \sigma_{ij} k_{dj} - \sigma_{ij} \sum_s \sigma_{cs} k_{ds} \end{align}$$

Note this is a 4D tensor with dimensions $3 \times 3 \times 3 \times 2$, since the indices $i,j,c$ run from 1 to 3 and $d$ runs from 1 to 2.
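To make the shapes concrete, here is a minimal NumPy sketch that assembles this tensor with einsum (softmax again taken over all nine entries of $A$, as in the question) and checks one slice with a finite difference:

```python
import numpy as np

def softmax_all(M):
    # Softmax over all entries of the matrix, as in the question.
    e = np.exp(M - np.max(M))
    return e / e.sum()

Q = np.random.randn(3, 2)
K = np.random.randn(2, 3)
P = softmax_all(Q @ K)    # P[i,j] = S(A)[i,j], the sigma from the formula

# dS(A)[i,j]/dQ[c,d] = delta(i,c) P[i,j] K[d,j] - P[i,j] * sum_s P[c,s] K[d,s]
term1 = np.einsum('ic,ij,dj->ijcd', np.eye(3), P, K)
term2 = np.einsum('ij,cs,ds->ijcd', P, P, K)
dS_dQ = term1 - term2     # shape (3, 3, 3, 2)

# Finite-difference check of the (c, d) = (2, 1) slice.
eps = 1e-6
dQ = np.zeros_like(Q); dQ[2, 1] = eps
num = (softmax_all((Q + dQ) @ K) - softmax_all((Q - dQ) @ K)) / (2 * eps)
print(np.allclose(dS_dQ[:, :, 2, 1], num, atol=1e-6))   # True
```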
