Solved – Why do we need the temperature in the Gumbel-Softmax trick

autoencoders, gumbel distribution

Assume a discrete variable $z_j$ with unnormalized probability $\alpha_j$. One way to sample is to apply argmax(softmax($\alpha_j$)); another is to use the Gumbel trick, argmax($\log\alpha_j+g_j$), where $g_j$ is Gumbel-distributed noise. This second approach is useful if we want to do something like variational autoencoding (i.e., encode an input $x_j$ into a latent discrete variable $z_j$). Then, if the goal was to have the full distribution over possible outcomes for $z_j$, we can use a softmax transformation on top of the perturbation with Gumbel noise:
$$\pi_j = \frac{e^{\log\alpha_j+g_j}}{\sum_{k=1}^{K}e^{\log\alpha_k+g_k}}\quad\text{where}\quad g_k=-\log(-\log(\epsilon_k)),\ \epsilon_k\sim U(0,1).$$
Why isn't this enough? Why do we need to include a temperature term $\tau$ and rewrite it as
$$\pi_j = \frac{e^{\frac{\log\alpha_j+g_j}{\tau}}}{\sum_{k=1}^{K}e^{\frac{\log\alpha_k+g_k}{\tau}}}\quad\text{where}\quad g_k=-\log(-\log(\epsilon_k)),\ \epsilon_k\sim U(0,1)?$$
I understand that the temperature makes the vector $\pi=[\pi_1, \dots,\pi_K]$ smoother or sharper (i.e., a high temperature makes all the $\pi_i$ roughly equal, generating a flatter distribution, and $\tau=1$ just makes the two equations identical), but why do we need it in practice? All we want (e.g., in a VAE) is to decouple the stochastic aspect of the sampling (i.e., move the stochastic part to the input), which is achieved by the Gumbel trick, and then replace the one-hot draw with a continuous vector, which we already get by computing softmax($\log\alpha_j+g_j$), i.e., the first equation. I am sure I am missing something fundamental, but I can't see what it is…

Best Answer

one way to sample is to apply argmax(softmax($\alpha_j$))

That is hardly "sampling", given that you deterministically pick the largest $\alpha_j$ every time. (Also, you said that $\alpha_j$ is an unnormalized probability, but that doesn't quite make sense, since it is log-probabilities, i.e. logits, that go into the softmax.) The correct way to sample would be sample(softmax($x$)), where $x$ are the logits. Indeed, the goal of Gumbel-Softmax is not to replace the softmax operation as you've written it, but the sampling operation:

We can replace sample($p$), where $p$ is a vector of probabilities, with argmax($\log p + g$), where $g$ is Gumbel noise. Of course, this is equivalent to argmax($x + g$), where $x$ are again the logits. To conclude, sample(softmax($x$)) and argmax($x+g$) are equivalent procedures.
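
If it helps, here is a minimal numerical sanity check of that equivalence (a sketch assuming NumPy; the logits are arbitrary illustration values): both routes produce the same empirical frequencies.

```python
# Compare sampling from softmax(x) with argmax(x + Gumbel noise).
import numpy as np

rng = np.random.default_rng(0)
x = np.array([2.0, 1.0, 0.1])                      # logits (illustrative)
p = np.exp(x) / np.exp(x).sum()                    # softmax(x)

n = 100_000
# Route 1: sample(softmax(x))
counts_softmax = np.bincount(rng.choice(len(x), size=n, p=p), minlength=len(x))

# Route 2: argmax(x + g), with g = -log(-log(U)), U ~ Uniform(0, 1)
u = rng.uniform(size=(n, len(x)))
g = -np.log(-np.log(u))
counts_gumbel = np.bincount(np.argmax(x + g, axis=1), minlength=len(x))

print("softmax probabilities:", np.round(p, 3))
print("sample(softmax(x))   :", np.round(counts_softmax / n, 3))
print("argmax(x + g)        :", np.round(counts_gumbel / n, 3))
```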

Then, if the goal was to have the full distribution over possible outcomes for $z_j$, we can use softmax transformation on top of the perturbation with Gumbel noise.

In fact, you already have a distribution over all possible outcomes: softmax($x$) itself.

However, argmax($x+g$) is not differentiable with respect to $x$, so to backpropagate we replace its gradient with the gradient of softmax($(x+g)\tau^{-1}$). As $\tau \rightarrow 0$, this expression approaches the (one-hot) argmax.
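
You can see the role of $\tau$ directly by evaluating the relaxed sample for a single draw of Gumbel noise at several temperatures (again a NumPy sketch with the same illustrative logits as above): as $\tau$ shrinks, softmax($(x+g)/\tau$) approaches the one-hot vector selected by argmax($x+g$).

```python
# How temperature controls how close softmax((x + g)/tau) is to one-hot.
import numpy as np

rng = np.random.default_rng(1)
x = np.array([2.0, 1.0, 0.1])                      # logits (illustrative)
g = -np.log(-np.log(rng.uniform(size=x.shape)))    # one draw of Gumbel(0, 1) noise

def relaxed_sample(x, g, tau):
    y = (x + g) / tau
    y -= y.max()                                   # for numerical stability
    e = np.exp(y)
    return e / e.sum()

print("argmax(x + g) picks index", np.argmax(x + g))
for tau in (5.0, 1.0, 0.5, 0.1):
    print(f"tau={tau}:", np.round(relaxed_sample(x, g, tau), 3))
```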

Picking a reasonably small value of $\tau$ gives a good estimate of the gradient while keeping the gradients numerically well behaved.

and $\tau=1$ just makes the two equations identical

In fact, there is no special significance to $\tau = 1$. Rather, $\tau \rightarrow 0$ makes the gradient estimate unbiased but high in variance, whereas larger values of $\tau$ add more bias to the gradient estimate but lower the variance.
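
A rough numerical sketch of that trade-off (assuming NumPy; the logits $x$ and the toy objective $f(z) = w^\top z$ are made up for illustration): for this linear $f$, the per-sample gradient of the relaxed draw $z=\text{softmax}((x+g)/\tau)$ has the closed form $\nabla_x f(z) = z\odot(w - w^\top z)/\tau$, which we compare against the exact gradient of $\mathbb{E}[f(z)] = w^\top\text{softmax}(x)$. Smaller $\tau$ should shrink the bias of the Monte Carlo estimate while inflating its variance.

```python
# Bias/variance of the Gumbel-Softmax gradient estimator vs. temperature.
import numpy as np

rng = np.random.default_rng(2)
x = np.array([2.0, 1.0, 0.1])                      # logits (illustrative)
w = np.array([2.0, -1.0, 0.5])                     # toy objective f(z) = w . z

p = np.exp(x) / np.exp(x).sum()
exact = p * (w - w @ p)                            # exact gradient of w . softmax(x)

n = 200_000
g = -np.log(-np.log(rng.uniform(size=(n, len(x)))))  # Gumbel(0, 1) noise

for tau in (0.1, 0.5, 1.0, 5.0):
    y = (x + g) / tau
    y -= y.max(axis=1, keepdims=True)              # numerical stability
    z = np.exp(y)
    z /= z.sum(axis=1, keepdims=True)              # relaxed samples, shape (n, 3)
    grads = z * (w - z @ w)[:, None] / tau         # per-sample gradient of f(z)
    bias = np.linalg.norm(grads.mean(axis=0) - exact)
    var = grads.var(axis=0).sum()
    print(f"tau={tau}: bias ~ {bias:.3f}, variance ~ {var:.3f}")
```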
