[Math] Gradient of a softmax applied on a linear function

derivatives, gradient descent, linear algebra, partial derivative

I am trying to calculate the softmax gradient:
$$p_j=[f(\vec{x})]_j = \frac{e^{W_jx+b_j}}{\sum_k e^{W_kx+b_k}}$$
With the cross-entropy error:
$$L = -\sum_j y_j \log p_j$$
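As a concrete reference for these two definitions, here is a small NumPy sketch that computes the softmax probabilities $p$ and the cross-entropy loss $L$ from $W$, $b$, $x$, and $y$ (the shapes and function name are my own assumptions, not part of the question):

```python
import numpy as np

def softmax_xent(W, b, x, y):
    """Cross-entropy of a softmax over the logits o = W @ x + b.

    Assumed shapes for this sketch: W is (n, d), x is (d,),
    b and y are (n,), with y a one-hot (or probability) vector.
    """
    o = W @ x + b                      # logits o_j = W_j x + b_j
    o = o - o.max()                    # shift logits for numerical stability
    p = np.exp(o) / np.exp(o).sum()    # softmax probabilities p_j
    L = -np.sum(y * np.log(p))         # cross-entropy loss
    return p, L
```

The max-shift does not change $p$ or $L$ (numerator and denominator are scaled by the same factor); it only avoids overflow in `np.exp`.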
From a related question I get that
$$\frac{\partial L}{\partial o_i} = p_i - y_i$$
Where $o_i=W_ix+b_i$

So, by applying the chain rule I get
$$\frac{\partial L}{\partial b_i}=\frac{\partial L}{\partial o_i}\frac{\partial o_i}{\partial b_i} = (p_i - y_i)\cdot 1=p_i - y_i,$$
which makes sense dimension-wise, and
$$\frac{\partial L}{\partial W_i}=\frac{\partial L}{\partial o_i}\frac{\partial o_i}{\partial W_i} = (p_i - y_i)\vec{x},$$
which has a dimensionality mismatch

(for example if dimensions are $W_{3\times 4},\vec{b}_4,\vec{x}_3$)

What am I doing wrong, and what is the correct gradient?

Best Answer

You can use differentials to tackle the problem.

Define the auxiliary variables $$\eqalign { o &= Wx+b \cr e &= \exp(o) \cr p &= \frac{e}{1:e} \cr }$$ with their corresponding differentials $$\eqalign { do &= dW\,x + db \cr de &= e\odot do \cr dp &= \frac{de}{1:e} - \frac{e(1:de)}{(1:e)^2} \,\,\,\,=\,\, (P-pp^T)\,do \cr }$$where : denotes the double-dot (aka Frobenius) product, and $\odot$ denotes the element-wise (aka Hadamard) product, and $P = \operatorname{Diag}(p)$.

Now substitute these into the cross-entropy function, and find its differential $$\eqalign { L &= -y:\log(p) \cr\cr dL &= -y:d\log(p) \cr &= -y:P^{-1}dp \cr &= -y:P^{-1}(P-pp^T)\,do \cr &= -y:(I-1p^T)\,do \cr &= (p1^T-I)y:(dW\,x + db) \cr &= (p1^T-I)yx^T:dW + (p1^T-I)y:db \cr\cr }$$ Setting $db=0$ yields the gradient wrt $W$ $$\eqalign { \frac{\partial L}{\partial W} &= (p1^T-I)yx^T \cr &= (p-y)x^T \cr }$$ while setting $dW=0$ yields the gradient wrt $b$ $$\eqalign { \frac{\partial L}{\partial b} &= (p1^T-I)y \cr &= p-y \cr }$$ Note that in the above derivation, the $\log$ and $\exp$ functions are applied element-wise to their vector arguments.
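The two closed forms can be checked numerically. The sketch below (an illustration under my own naming assumptions, not part of the answer's derivation) compares $(p-y)x^T$ and $p-y$ against central finite differences of the loss:

```python
import numpy as np

def loss(W, b, x, y):
    """Cross-entropy of softmax(W @ x + b) against target y."""
    o = W @ x + b
    p = np.exp(o - o.max()); p /= p.sum()
    return -np.sum(y * np.log(p))

def check_gradients(W, b, x, y, eps=1e-6):
    """Return True if (p - y) x^T and p - y match finite differences."""
    o = W @ x + b
    p = np.exp(o - o.max()); p /= p.sum()
    gW = np.outer(p - y, x)            # claimed dL/dW, shape (n, d)
    gb = p - y                         # claimed dL/db, shape (n,)

    # Central differences on every entry of W and b.
    numW = np.zeros_like(W)
    for i in range(W.shape[0]):
        for j in range(W.shape[1]):
            Wp = W.copy(); Wp[i, j] += eps
            Wm = W.copy(); Wm[i, j] -= eps
            numW[i, j] = (loss(Wp, b, x, y) - loss(Wm, b, x, y)) / (2 * eps)
    numb = np.zeros_like(b)
    for i in range(b.size):
        bp = b.copy(); bp[i] += eps
        bm = b.copy(); bm[i] -= eps
        numb[i] = (loss(W, bp, x, y) - loss(W, bm, x, y)) / (2 * eps)

    return np.allclose(gW, numW, atol=1e-5) and np.allclose(gb, numb, atol=1e-5)
```

Note that $(p-y)x^T$ has shape $n\times d$, the same shape as $W$, which resolves the dimensionality mismatch in the question.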

Based on your expected results, you appear to be using an unstated constraint that $1^Ty=1$ (i.e. $y$ is a one-hot or probability vector), which I have used to simplify the final results: $(p1^T-I)\,y = p\,(1^Ty) - y = p - y$.
