Computing the log of a sum of exponentials

algebra-precalculus, exponential-function, logarithms

In a Coursera course by UW, I came across this piece of code, which computes the log of a sum of exponentials:

import numpy as np

def log_sum_exp(Z):
    r"""Compute log(\sum_i exp(Z_i)) for some array Z."""
    z_max = np.max(Z)  # shift by the maximum so that np.exp cannot overflow
    return z_max + np.log(np.sum(np.exp(Z - z_max)))

I've been trying to figure out how this computes:

$$\log\left(\sum_{i=1}^n e^{Z_i}\right)$$

I've been trying to factor it for some time now, but at least for today I'm at my wit's end.
Could someone please explain?

Thank you very much!

Best Answer

This is a widely used trick in numerical computing to avoid the overflow that a direct log_sum_exp computation can cause. The culprit is the exponential of the largest value $x_{\mathrm{max}} = \underset{i}{\max}\{x_i\}$ of the given vector $x = (x_1,~\dots~,x_n)^{\mathsf{T}}$.

Let us first introduce the following function: $$ f(x) = \mathrm{log\_sum\_exp}(x) = \log\sum\limits_{i=1}^n \exp(x_i) \tag{1} $$

Applying the exponential to both sides of $(1)$ gives
$$ \exp(f(x)) = \sum\limits_{i=1}^n \exp(x_i) \tag{2} $$

The term that risks overflow is the exponential of the largest element, $\exp(\underset{j}{\max}\{x_j\})$. We factor it out of the sum on the right-hand side of $(2)$ by dividing every summand by it:
$$ \exp(f(x)) = \exp(\underset{j}{\max}\{x_j\})\sum\limits_{i=1}^n \exp(x_i - \underset{j}{\max}\{x_j\}) \tag{3} $$

Inside the sum, every exponent is now at most $0$: the largest element contributes $\exp(0) = 1$, and every other term lies between $0$ and $1$, so none of them can overflow. Taking the logarithm of both sides of $(3)$ yields
$$ f(x) = \mathrm{log\_sum\_exp}(x) = \underset{j}{\max}\{x_j\} + \log\sum\limits_{i=1}^n\exp(x_i-\underset{j}{\max}\{x_j\}) \tag{4} $$
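As a quick sanity check of $(4)$, here is a small worked example with numbers I picked purely for illustration, $x = (1000, 1001, 1002)^{\mathsf{T}}$, so that $\underset{j}{\max}\{x_j\} = 1002$:
$$ f(x) = 1002 + \log\left(e^{-2} + e^{-1} + e^{0}\right) \approx 1002 + \log(1.5032) \approx 1002.4076 $$
Evaluating $(1)$ directly would require $e^{1002}$, which is far larger than the biggest finite double-precision number (roughly $1.8 \times 10^{308}$, i.e. about $e^{709.78}$), so it would overflow to infinity.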

With this trick, we never compute the exponential of the largest number in the sequence directly. In the NumPy code you are using, Z - np.max(Z) is computed first and only then is np.exp applied, so every argument passed to np.exp is at most 0.
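To see the difference in practice, here is a minimal sketch comparing the direct formula with the shifted one. The function names naive_log_sum_exp and stable_log_sum_exp and the test vector are my own, chosen for illustration; they are not from the course code.

```python
import numpy as np

def naive_log_sum_exp(z):
    """Direct evaluation of log(sum(exp(z))); overflows for large entries."""
    return np.log(np.sum(np.exp(z)))

def stable_log_sum_exp(z):
    """Shifted evaluation: subtract the maximum before exponentiating."""
    z = np.asarray(z, dtype=float)
    z_max = np.max(z)
    return z_max + np.log(np.sum(np.exp(z - z_max)))

# Illustrative input: two equal, large entries (an assumption for the demo).
z = np.array([1000.0, 1000.0])

# np.exp(1000.0) overflows in float64; silence the warning for this demo only.
with np.errstate(over="ignore"):
    naive = naive_log_sum_exp(z)

stable = stable_log_sum_exp(z)

print(naive)   # inf
print(stable)  # 1000 + log(2), about 1000.6931
```

For inputs that are small enough not to overflow, both functions should agree up to rounding error, so the shift costs nothing in accuracy.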

I hope this answers your question. If anything is still unclear, I recommend reading through this small article :)
