Variance of reparameterization trick and score function

gradient descent, policy gradient, reinforcement learning, scoring-rules, variance

For an expectation $\mathbf E_{z\sim q_\phi(z|x)}[f(z)]$ (assuming $f$ is continuous), where $q_\phi$ is a Gaussian distribution, there are two ways to compute its gradient with respect to $\phi$.

  1. Compute the score function estimator:

    $$
    \begin{align}
    \nabla_\phi\mathbf E_{z\sim q_\phi(z|x)}[f(z)]&=\nabla_\phi \int f(z)q_\phi(z|x)dz\\
    &=\int f(z)\nabla_\phi q_\phi(z|x)dz\\
    &=\int {q_\phi(z|x)\over q_\phi(z|x)}f(z)\nabla_\phi q_\phi(z|x)dz\\
    &=\int q_\phi(z|x)f(z)\nabla_\phi \log q_\phi(z|x)dz\\
    &=\mathbf E_{z\sim q_\phi(z|x)}[f(z)\nabla_\phi\log q_\phi(z|x)]\tag 1
    \end{align}
    $$

  2. Use the reparameterization trick: let $z=\mu_\phi(x)+\epsilon\sigma_\phi(x)$, where $\epsilon\sim\mathcal N(0,1)$. Differentiating the objective then gives
    $$
    \nabla_\phi\mathbf E_{z\sim q_\phi(z|x)}[f(z)]=\mathbf E_{\epsilon\sim\mathcal N(0,1)}[\nabla_\phi f(\mu_\phi(x)+\epsilon\sigma_\phi(x))]\tag 2
    $$
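For the Gaussian case both gradients can be written out explicitly. Writing $\mu=\mu_\phi(x)$ and $\sigma=\sigma_\phi(x)$ for brevity, the score needed in $(1)$ and the pathwise derivative needed in $(2)$ are (standard results, sketched here by the chain rule):

$$
\begin{aligned}
\nabla_\phi\log q_\phi(z|x)&={z-\mu\over\sigma^2}\nabla_\phi\mu+{(z-\mu)^2-\sigma^2\over\sigma^3}\nabla_\phi\sigma,\\
\nabla_\phi f(\mu+\epsilon\sigma)&=f'(\mu+\epsilon\sigma)\left(\nabla_\phi\mu+\epsilon\nabla_\phi\sigma\right).
\end{aligned}
$$

The second line uses $f'$, so $(2)$ requires $f$ to be differentiable rather than merely continuous, whereas $(1)$ only needs evaluations of $f$.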

According to this video, at around 58 min the instructor explains that computing the gradient with the reparameterization trick generally has lower variance than the score function estimator. Here is my understanding of the instructor's explanation, though I'm not sure I have it right. Please point out any misunderstanding 🙂

Eq. $(1)$ has high variance because $f(z)$ is evaluated at random samples, and its variance is unbounded. Multiplying it by $\nabla_\phi \log q_\phi(z|x)$ therefore gives a high-variance gradient. In Eq. $(2)$, on the other hand, the coefficient is fixed except for $\epsilon$, which has variance $1$. As a result, Eq. $(1)$ has higher variance than Eq. $(2)$.
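A small worked case shows how the size of $f(z)$ enters the variance. The choice $f(z)=z$, differentiating only with respect to $\mu$, is my own illustration. With $z=\mu+\sigma\epsilon$, a single-sample estimate from each equation is

$$
\begin{aligned}
g_1&=f(z)\,{z-\mu\over\sigma^2}={(\mu+\sigma\epsilon)\epsilon\over\sigma}={\mu\over\sigma}\epsilon+\epsilon^2,
&\operatorname{Var}(g_1)&={\mu^2\over\sigma^2}+2,\\
g_2&={\partial\over\partial\mu}(\mu+\sigma\epsilon)=1,
&\operatorname{Var}(g_2)&=0.
\end{aligned}
$$

Both have mean $1$, the true derivative of $\mathbf E[z]=\mu$. The variance of $g_1$ follows from $\mathbf E[\epsilon^3]=0$ and $\operatorname{Var}(\epsilon^2)=2$, and it grows with $\mu^2/\sigma^2$, i.e. with the magnitude of $f(z)$.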

Best Answer

The notation here is far more complicated than it needs to be, and I suspect this is contributing to the issue of understanding this method. To clarify the problem, I'm going to re-frame this in standard notation. I'm also going to remove reference to $x$, because the entire analysis is conditional on this value, so it adds nothing to the problem beyond complicating the notation.

You have a problem with a Gaussian random variable $Z \sim \text{N}(\mu(\phi), \sigma(\phi)^2)$, where the mean and variance depend on a parameter $\phi$. You can also define the error term $\epsilon \equiv (Z - \mu(\phi))/\sigma(\phi)$ which measures the number of standard deviations from the mean. Now, you want to compute the gradient of the expected value:

$$\begin{equation} \begin{aligned} J(\phi) \equiv \mathbb{E}(r(Z)) &= \int \limits_\mathbb{R} r(z) \cdot \text{N}(z|\mu(\phi), \sigma(\phi)^2) \ dz \\[6pt] &= \int \limits_\mathbb{R} r(\mu(\phi) + \epsilon \cdot \sigma(\phi)) \cdot \text{N}(\epsilon|0,1) \ d\epsilon. \\[6pt] \end{aligned} \end{equation}$$

(The equivalence of these two integral expressions is a consequence of the change-of-variable formula for integrals.) Differentiating these expressions gives the two equivalent forms:

$$\begin{equation} \begin{aligned} \nabla_\phi J(\phi) &= \int \limits_\mathbb{R} r(z) \bigg( \nabla_\phi \ln \text{N}(z|\mu(\phi), \sigma(\phi)^2) \bigg) \cdot \text{N}(z|\mu(\phi), \sigma(\phi)^2) \ dz \\[6pt] &= \int \limits_\mathbb{R} \bigg( \nabla_\phi r(\mu(\phi) + \epsilon \cdot \sigma(\phi)) \bigg) \cdot \text{N}(\epsilon|0,1) \ d\epsilon. \\[6pt] \end{aligned} \end{equation}$$

Both of these expressions are valid expressions for the gradient of interest, and both can be approximated by corresponding finite sums from simulated values of the random variables in the expressions. To do this we can generate a finite set of values $\epsilon_1,...,\epsilon_M \sim \text{IID N}(0,1)$ and form the corresponding values $z_j = \mu(\phi) + \epsilon_j \cdot \sigma(\phi)$ for $j=1,...,M$. Then we can use one of the following estimators:

$$\begin{equation} \begin{aligned} \nabla_\phi J(\phi) \approx \hat{E}_1(\phi) &\equiv \frac{1}{M} \sum_{j=1}^M r(z_j) \bigg( \nabla_\phi \ln \text{N}(z_j|\mu(\phi), \sigma(\phi)^2) \bigg), \\[10pt] \nabla_\phi J(\phi) \approx \hat{E}_2(\phi) &\equiv \frac{1}{M} \sum_{j=1}^M \nabla_\phi r(\mu(\phi) + \epsilon_j \cdot \sigma(\phi)). \end{aligned} \end{equation}$$
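The two estimators can be sketched in NumPy. This is a minimal illustration, assuming the parameter is simply $\phi=(\mu,\sigma)$ (so $\nabla_\phi\mu=(1,0)$ and $\nabla_\phi\sigma=(0,1)$) and a hypothetical test function $r(z)=z^2$, for which $J=\mu^2+\sigma^2$ and $\nabla_\phi J=(2\mu,2\sigma)$. The function names are my own, and the same draws $\epsilon_j$ feed both estimators, as described above.

```python
import numpy as np


def score_function_estimate(r, mu, sigma, eps):
    """E_1: average of r(z_j) * grad_phi ln N(z_j | mu, sigma^2), with phi = (mu, sigma)."""
    z = mu + sigma * eps
    # Gradient of the Gaussian log-density with respect to mu and sigma
    dlogp_dmu = (z - mu) / sigma**2
    dlogp_dsigma = ((z - mu) ** 2 - sigma**2) / sigma**3
    rz = r(z)
    return np.array([np.mean(rz * dlogp_dmu), np.mean(rz * dlogp_dsigma)])


def reparam_estimate(r_prime, mu, sigma, eps):
    """E_2: average of grad_phi r(mu + eps_j * sigma) = r'(z_j) * (1, eps_j)."""
    z = mu + sigma * eps
    rp = r_prime(z)
    return np.array([np.mean(rp), np.mean(rp * eps)])


# Hypothetical example: r(z) = z^2, so the true gradient is (2 mu, 2 sigma)
mu, sigma = 1.5, 0.5
rng = np.random.default_rng(0)
eps = rng.standard_normal(200_000)  # eps_1, ..., eps_M ~ IID N(0, 1)

g1 = score_function_estimate(lambda z: z**2, mu, sigma, eps)
g2 = reparam_estimate(lambda z: 2 * z, mu, sigma, eps)
print("true:", [2 * mu, 2 * sigma])
print("E_1 :", g1)
print("E_2 :", g2)
```

Both estimates should land near $(3, 1)$; note that only $\hat{E}_2$ needs the derivative $r'$, while $\hat{E}_1$ only evaluates $r$.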

The speaker asserts (but does not demonstrate) that the variance of the second estimator is lower than the variance of the first. He claims that this is because the latter estimator uses direct information about the gradient of $r$ whereas the first estimator uses information about the gradient of the log-density for the normal distribution. Personally, without more knowledge of the nature of $r$, this seems to me to be an unsatisfying explanation, and I can see why you are confused by it. I doubt that this result would hold for all functions $r$, but perhaps within the context of that field, the function $r$ tends to have a gradient that is fairly insensitive to changes in the argument value.
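The doubt expressed above can be probed numerically. The sketch below is my own illustration (the test functions and constants are assumptions, not taken from the video): it compares the per-sample variance of the $\mu$-component of each estimator at $\mu=0,\ \sigma=1$. For the smooth $r(z)=z^2$ the pathwise terms $2z$ have the smaller variance (analytically $4$ versus $15$ for the score terms $z^3$). For a rapidly oscillating $r(z)=\sin(kz)$ with $k=20$ the ordering reverses: the pathwise term $k\cos(kz)$ has variance close to $k^2/2=200$, while the score term $z\sin(kz)$ has variance close to $1/2$.

```python
import numpy as np


def per_sample_mu_grads(r, r_prime, mu, sigma, eps):
    """Per-sample terms of the d/dmu component of E_1 (score) and E_2 (pathwise)."""
    z = mu + sigma * eps
    score_terms = r(z) * (z - mu) / sigma**2
    pathwise_terms = r_prime(z)
    return score_terms, pathwise_terms


rng = np.random.default_rng(1)
eps = rng.standard_normal(500_000)
mu, sigma = 0.0, 1.0

# Smooth r(z) = z^2: analytic variances are 15 (score) and 4 (pathwise)
s_quad, p_quad = per_sample_mu_grads(lambda z: z**2, lambda z: 2 * z, mu, sigma, eps)

# Oscillating r(z) = sin(k z): variances are about 1/2 (score) and k^2 / 2 (pathwise)
k = 20.0
s_osc, p_osc = per_sample_mu_grads(
    lambda z: np.sin(k * z), lambda z: k * np.cos(k * z), mu, sigma, eps
)

for name, s, p in [("z^2", s_quad, p_quad), ("sin(20 z)", s_osc, p_osc)]:
    # Variance of an M-sample estimator is the per-sample variance divided by M
    print(f"{name:10s} score var = {s.var():8.3f}   pathwise var = {p.var():8.3f}")
```

So which estimator wins depends on $r$: a gradient that changes quickly with its argument can make the reparameterization estimator the noisier one.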
