Solved – Gradient of loss function for (non)-linear prediction functions

gradient, loss-functions, optimization

$
\newcommand{\y}{\mathbf{y}}
\newcommand{\wv}{\mathbf{w}}
\newcommand{\xv}{\mathbf{x}}
\newcommand{\loss}{L(\wv;\xv, y)}
$
I'm trying to clear up the calculation of the gradient of a loss function, and how it changes when your prediction function is non-linear.

I'd like to verify my thinking is correct:

We are trying to predict a scalar $y_i$ from a feature vector $\xv_i$ using a weight vector $\wv$, where $\wv$ and $\xv_i$ are vectors of the same dimension. In the following we omit the index $i$.

In order to do that, we want to minimize a loss function $\loss = f\,(y, pred(\xv, \wv))$,
which depends on $y$, $\xv$ and $\wv$. If we take the squared loss as a specific example, then (disregarding any regularization):

$$
\loss = (y - pred(\xv, \wv))^2
$$

where $pred(\xv, \wv)$ is the prediction function. When we use a linear prediction, this is $pred(\xv, \wv) = \wv^T\xv$.

Now, in order to minimize the loss, using, for example, a first-order method such as stochastic gradient descent, we need to find the gradient of the loss function with respect to $\wv$. So we have:

$$
\nabla_w \loss = \nabla_w f\,(y, pred(\xv, \wv))
\\= \frac{\partial f\,(y, p)}{\partial p}\bigg\rvert_{p = pred(\xv, \wv)} \cdot \nabla_w\, pred(\xv, \wv)
\tag{lossGradient}\label{lossGradient}
$$

This is where I get a bit lost with the chain rule so correct me if I'm wrong:

The first element in $\eqref{lossGradient}$ is a scalar (the derivative of the loss $f$ with respect to the prediction, evaluated at $pred(\xv, \wv)$), and the second element is a vector (the gradient of the prediction function w.r.t. $\wv$). Both of these depend on the prediction function.

In the case of linear prediction we have, as we saw, $pred(\xv, \wv) = \wv^T\xv$
and $\nabla_w pred(\xv, \wv) = \xv$.
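
For what it's worth, here is a minimal NumPy sketch (my own sanity check, not taken from any library) comparing this per-example gradient, which works out to $-2\,(y - \wv^T\xv)\,\xv$, against a finite-difference approximation:

```python
import numpy as np

# Per-example squared loss with a linear prediction:
#   L(w; x, y) = (y - w^T x)^2
# Chain rule: dL/dpred = -2 (y - w^T x), grad_w pred = x,
# so grad_w L = -2 (y - w^T x) x.

rng = np.random.default_rng(0)
x = rng.normal(size=5)
w = rng.normal(size=5)
y = 1.3

def loss(w):
    return (y - w @ x) ** 2

analytic = -2.0 * (y - w @ x) * x

# Central finite differences, one coordinate of w at a time.
eps = 1e-6
numeric = np.array([(loss(w + eps * e) - loss(w - eps * e)) / (2 * eps)
                    for e in np.eye(len(w))])

print(np.allclose(analytic, numeric, atol=1e-6))  # expected: True
```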

With a non-linear prediction function, $\nabla_w pred(\xv, \wv)$ will in general no longer be the feature vector $\xv$, but the calculation of the loss gradient has the same structure, i.e. $\text{loss derivative} \times \text{prediction gradient}$.

Is that correct? Could you give me some examples of non-linear prediction functions that have a gradient that is not the feature vector $\xv$?

Best Answer

The chain rule is $$ (f(g(x)))' = f'(g(x))\,g'(x). $$ In a multivariate setting, with $f:\mathbb{R}^n\rightarrow\mathbb{R}$ and $g:\mathbb{R}^p\rightarrow\mathbb{R}^n$, this becomes $$ \nabla_w f(g(w)) = \nabla g(w)\,\nabla f(g(w)), $$ where $\nabla g(w)$ denotes the $p \times n$ transposed Jacobian of $g$. Written with partial derivatives, the chain rule reads $$ \frac{\partial}{\partial x} f(g(x)) = \frac{\partial f}{\partial z}\frac{\partial z}{\partial x}, $$ where $z = g(x)$.

Thus, if we let $$ L(w;X,y) = f(y, \text{pred}(X, w)) := f(y - \text{pred}(X, w)) = \|y - \text{pred}(X, w)\|_2^2 = \|y - Xw\|_2^2, $$ then $$ f(z) = \|z\|_2^2 = z^tz $$ and $$ z = g(w) = y - \text{pred}(X, w) = y - Xw. $$ We have $$ \nabla f(z) = \nabla (z^tz) = 2z $$ and $$ \nabla g(w) = \nabla_w (y - Xw) = -X^t. $$

Hence, the gradient of your linear estimator is $$ \nabla L(w;X,y) = \nabla g(w)\,\nabla f(z) = -X^t\cdot 2z = -2X^t(y - Xw). $$
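
As a quick check (my own NumPy sketch, not part of the derivation itself), the closed-form gradient $-2X^t(y - Xw)$ agrees with a finite-difference approximation of $\|y - Xw\|_2^2$:

```python
import numpy as np

# Full least-squares loss L(w; X, y) = ||y - Xw||_2^2 and its gradient
#   grad_w L = -2 X^T (y - Xw)   (note the minus sign).

rng = np.random.default_rng(1)
X = rng.normal(size=(20, 4))
y = rng.normal(size=20)
w = rng.normal(size=4)

def loss(w):
    r = y - X @ w
    return r @ r

analytic = -2.0 * X.T @ (y - X @ w)

# Central finite differences over each coordinate of w.
eps = 1e-6
numeric = np.array([(loss(w + eps * e) - loss(w - eps * e)) / (2 * eps)
                    for e in np.eye(len(w))])

print(np.allclose(analytic, numeric, atol=1e-5))  # expected: True
```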

An example of a non-linear prediction function whose gradient is not the feature vector $x$ is the sigmoid prediction used in logistic regression, $\text{pred}(x, w) = \sigma(w^Tx)$ with $\sigma(t) = 1/(1 + e^{-t})$: its gradient with respect to $w$ is $\sigma(w^Tx)\,(1 - \sigma(w^Tx))\,x$, a rescaled copy of $x$ rather than $x$ itself.
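
To make this concrete, here is a small NumPy sketch (keeping the squared loss from the question purely for illustration; logistic regression is normally fit with the log loss) that plugs the sigmoid prediction into the same chain-rule pattern and checks the result numerically:

```python
import numpy as np

# Non-linear prediction: pred(x, w) = sigmoid(w^T x).
# Its gradient w.r.t. w is sigmoid(w^T x) * (1 - sigmoid(w^T x)) * x,
# i.e. a rescaled copy of x, not x itself. The loss gradient is still
# (loss derivative w.r.t. prediction) * (prediction gradient).

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

rng = np.random.default_rng(2)
x = rng.normal(size=3)
w = rng.normal(size=3)
y = 1.0

def pred(w):
    return sigmoid(w @ x)

def loss(w):                       # squared loss from the question, for illustration
    return (y - pred(w)) ** 2

s = pred(w)
grad_pred = s * (1 - s) * x        # gradient of the prediction, not just x
analytic = -2.0 * (y - s) * grad_pred

# Central finite differences over each coordinate of w.
eps = 1e-6
numeric = np.array([(loss(w + eps * e) - loss(w - eps * e)) / (2 * eps)
                    for e in np.eye(len(w))])

print(np.allclose(analytic, numeric, atol=1e-6))  # expected: True
```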
