Max Log Probability vs Probability – Why Optimize the Maximum Log Probability?


In most machine learning tasks where you can formulate some probability $p$ that should be maximized, we actually optimize the log probability $\log p$ instead of the probability itself, with respect to some parameters $\theta$. E.g., in maximum likelihood training, it's usually the log-likelihood. When doing this with a gradient method, the gradient picks up a factor of $1/p$:

$$ \frac{\partial \log p}{\partial \theta} = \frac{1}{p} \cdot \frac{\partial p}{\partial \theta} $$
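A quick numerical check of this identity, as a minimal sketch. It assumes a hypothetical Bernoulli likelihood with `K = 7` successes in `N = 10` trials (binomial coefficient dropped) and evaluates both sides at $\theta = 0.4$; these values are illustrative assumptions, not from any particular model in the question.

```python
import math

# Hypothetical example: likelihood of K successes in N Bernoulli trials
# as a function of the success probability theta (binomial coefficient omitted).
K, N = 7, 10

def p(theta: float) -> float:
    return theta**K * (1.0 - theta) ** (N - K)

def dlogp_analytic(theta: float) -> float:
    # d/dtheta [K log(theta) + (N - K) log(1 - theta)]
    return K / theta - (N - K) / (1.0 - theta)

def central_diff(f, x: float, h: float = 1e-6) -> float:
    # Central finite-difference approximation of f'(x)
    return (f(x + h) - f(x - h)) / (2.0 * h)

theta = 0.4
lhs = dlogp_analytic(theta)               # d log p / d theta
rhs = central_diff(p, theta) / p(theta)   # (1/p) * dp/d theta
print(f"d log p / d theta  = {lhs:.6f}")
print(f"(1/p) * dp/d theta = {rhs:.6f}")
```

Both sides should agree to roughly six decimal places (the finite difference limits the precision).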

See here or here for some examples.

Of course, the two problems have the same optimum, since $\log$ is monotonic, but the gradient will be different, so any gradient-based method will behave differently (esp. stochastic gradient methods).
Is there any justification that the $\log p$ gradient works better than the $p$ gradient?

Best Answer

Gradient methods generally work better optimizing $\log p(x)$ than $p(x)$ because the gradient of $\log p(x)$ is generally better scaled. That is, its magnitude consistently and helpfully reflects the objective function's geometry, making it easier to select an appropriate step size and reach the optimum in fewer steps.

To see what I mean, compare the gradient optimization process for $p(x) = \exp(-x^2)$ and $f(x) = \log p(x) = -x^2$. At any point $x$, the gradient of $f(x)$ is $$f'(x) = -2x.$$ If we multiply that by $1/2$, we get exactly the step needed to reach the global optimum at the origin, $x + \tfrac{1}{2} f'(x) = 0$, no matter what $x$ is. This means we don't have to work hard to find a good step size (or "learning rate" in ML jargon). Wherever our initial point is, we just set our step to half the gradient and we'll be at the origin in one step. And if we don't know the exact factor needed, we can pick a step size around 1, do a bit of line search, and quickly find a great step size, one that works well no matter where $x$ is. This property is robust to translation and scaling of $f(x)$: scaling $f(x)$ will make the optimal step scaling differ from 1/2, but that scaling will still be the same no matter what $x$ is, so we only have to find one parameter to get an efficient gradient-based optimization scheme.
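A minimal sketch of the one-step claim, assuming plain gradient ascent $x \leftarrow x + \eta f'(x)$ with learning rate $\eta = 1/2$; the starting points are arbitrary choices for illustration.

```python
def grad_f(x: float) -> float:
    # f(x) = log p(x) = -x**2, so f'(x) = -2x
    return -2.0 * x

def ascent_step(x: float, grad, lr: float) -> float:
    # One step of gradient ascent (we are maximizing f)
    return x + lr * grad(x)

starts = [-10.0, -0.3, 0.5, 5.0, 1e6]
after_one_step = [ascent_step(x0, grad_f, lr=0.5) for x0 in starts]
print(after_one_step)  # [0.0, 0.0, 0.0, 0.0, 0.0]
```

Every start, near or far, lands on the optimum, because the same learning rate is right everywhere.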

In contrast, the gradient of $p(x)$ has very poor global properties for optimization. We have $$p'(x) = f'(x) p(x) = -2x \exp(-x^2).$$ This multiplies the perfectly nice, well-behaved gradient $-2x$ by a factor $\exp(-x^2)$, which decays faster than exponentially as $|x|$ increases. At $x = 5$ we already have $\exp(-x^2) = \exp(-25) \approx 1.4 \cdot 10^{-11}$, so a step along the gradient vector is about $10^{-11}$ times too small. To get a reasonable step toward the optimum, we'd have to scale the gradient by the reciprocal of that, an enormous constant $\sim 10^{11}$. Such a badly scaled gradient is worse than useless for optimization purposes: we'd be better off just attempting a unit step in the uphill direction than setting our step by scaling against $p'(x)$! (In many variables $p'(x)$ becomes a bit more useful, since we at least get directional information from the gradient, but the scaling issue remains.)
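The same kind of sketch for $p(x) = \exp(-x^2)$, again assuming plain gradient ascent. It first applies the learning rate that worked for $\log p$, then computes the learning rate that would land exactly on the optimum in one step. Solving $x + \eta\, p'(x) = 0$ gives $\eta(x) = e^{x^2}/2$, which depends heavily on the starting point.

```python
import math

def p(x: float) -> float:
    return math.exp(-x * x)

def grad_p(x: float) -> float:
    # p'(x) = f'(x) * p(x) = -2x * exp(-x**2)
    return -2.0 * x * p(x)

# The learning rate that solved the log problem in one step barely moves x here
x0 = 5.0
x1 = x0 + 0.5 * grad_p(x0)
print(f"p'(5) = {grad_p(x0):.3e}, x after one step = {x1!r}")

# Learning rate that would land exactly on the optimum: x + lr * p'(x) = 0
for x in [0.5, 1.0, 2.0, 5.0]:
    lr_needed = -x / grad_p(x)  # equals exp(x**2) / 2
    print(f"x = {x}: learning rate needed = {lr_needed:.3e}")
```

Near the optimum the needed rate is below 1, but at $x = 5$ it is $e^{25}/2 \approx 3.6 \cdot 10^{10}$, so no single learning rate serves every starting point.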

In general there is no guarantee that $\log p(x)$ will have gradient scaling properties as nice as this toy example's, especially when we have more than one variable. However, for pretty much any nontrivial problem, $\log p(x)$ is going to be way, way better than $p(x)$. This is because the likelihood is a big product of many terms, and the log turns that product into a sum, as noted in several other answers. Provided the terms in the likelihood are well-behaved from an optimization standpoint, their logs are generally well-behaved, and a sum of well-behaved functions is well-behaved. By well-behaved I mean that $f''(x)$ doesn't change too much or too rapidly, leading to a nearly quadratic function that is easy to optimize by gradient methods. The derivative of a sum is the sum of the derivatives, whatever the derivative's order, which helps ensure that the big pile of summed terms has a very reasonable second derivative!
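To make the product-to-sum argument concrete, here is a minimal sketch with a hypothetical Gaussian model: 50 simulated observations with unknown mean $\mu$ and known $\sigma = 1$ (the data, seed, and sizes are all assumptions for illustration). Each log term is exactly quadratic in $\mu$, so the summed log-likelihood has the constant second derivative $-n/\sigma^2$, and gradient ascent with the constant learning rate $\eta = \sigma^2/n$ (which coincides with a Newton step here) lands on the maximum-likelihood estimate, the sample mean, in one step. The raw likelihood's gradient is that same useful gradient multiplied by a vanishingly small $L(\mu)$.

```python
import math
import random

# Hypothetical data: n draws from N(3, sigma^2) with sigma = 1 assumed known
rng = random.Random(0)
SIGMA = 1.0
data = [rng.gauss(3.0, SIGMA) for _ in range(50)]
n = len(data)

def log_lik(mu: float) -> float:
    # log of a product of Gaussian densities = sum of per-term logs
    return sum(
        -0.5 * math.log(2 * math.pi * SIGMA**2) - (x - mu) ** 2 / (2 * SIGMA**2)
        for x in data
    )

def grad_log_lik(mu: float) -> float:
    # derivative of the sum = sum of per-term derivatives (x - mu) / sigma^2
    return sum((x - mu) / SIGMA**2 for x in data)

# Each term contributes -1/sigma^2 to the second derivative, whatever mu is
HESS_LOG_LIK = -n / SIGMA**2

def grad_lik(mu: float) -> float:
    # gradient of the raw likelihood: L(mu) * d log L / d mu
    return math.exp(log_lik(mu)) * grad_log_lik(mu)

mu0 = 0.0
print(f"L(mu0)         = {math.exp(log_lik(mu0)):.3e}")
print(f"d log L / d mu = {grad_log_lik(mu0):.3e}")
print(f"d L / d mu     = {grad_lik(mu0):.3e}")

# log L is exactly quadratic in mu, so one step with lr = -1 / HESS_LOG_LIK
# (= sigma^2 / n) lands on the maximum-likelihood estimate, the sample mean
lr = -1.0 / HESS_LOG_LIK
mu1 = mu0 + lr * grad_log_lik(mu0)
sample_mean = sum(data) / n
print(f"mu after one step = {mu1:.6f}, sample mean = {sample_mean:.6f}")
```

In a less idealized model the second derivative of the log-likelihood would only be approximately constant, which is the "nearly quadratic" case described above.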