Solved – Variational autoencoder with L2-regularisation

autoencoders, keras, neural-networks, regularization, tensorflow

I have built a variational autoencoder (VAE) with Keras in TensorFlow 2.0, based on the following model from Seo et al. (link to paper here). The VAE is used for image reconstruction.
[Figure: VAE architecture from Seo et al.]

Note that the two layers with dimensions 1x1x16 output mu and log_var, used for the calculation of the Kullback-Leibler divergence (KL-div).

In my architecture, the sampling of a value from the latent space is implemented with a Lambda layer:

lat_var = Lambda(sampling, output_shape=(1, 1, 16), name='latent')([z_mean, z_log_var])

with sampling implemented the following way:

from tensorflow.keras import backend as K

def sampling(args):
    # Draw one sample from N(z_mean, exp(z_log_var)) via a standard normal epsilon
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(1, 1, 16))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon
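
This is the usual reparameterization trick: instead of sampling z directly, the layer computes

$$ z = \mu + \exp\!\left( \tfrac{1}{2} \log \sigma^{2} \right) \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I), $$

so the randomness sits in $\epsilon$ and gradients can flow back through z_mean and z_log_var.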

I was wondering if it makes sense to impose L2 regularization on the layers underlined in red, since the KL-div already imposes a constraint and acts as a regularization term.
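
For concreteness, this is roughly how such a penalty would be attached in Keras (a minimal sketch; the input shape, filter count and regularization factor below are placeholders, not my actual architecture):

from tensorflow.keras import Input, regularizers
from tensorflow.keras.layers import Conv2D

# Hypothetical encoder block with an L2 penalty on the convolution kernel;
# Keras adds this penalty to the total loss alongside the VAE objective.
inp = Input(shape=(64, 64, 1))
x = Conv2D(16, (3, 3), activation='relu', padding='same',
           kernel_regularizer=regularizers.l2(1e-4))(inp)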

Mathematical background: The objective function for the VAE is the mean of the reconstruction loss (in red) and the KL-div (in blue), as shown in the formula from Seo et al.

[Figure: VAE objective from Seo et al., with the reconstruction loss in red and the KL divergence in blue.]

During optimization, minimizing the objective function minimizes both the reconstruction loss and the KL-div, so the KL-div already imposes a constraint and acts as a regularization term. If we add L2 regularization to the objective function, this would add an additional constraint that penalizes large weights in the marked layers (see Andrew Ng on L2 regularization).
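
Written out, the modified objective would look roughly like this (a sketch; the weight $\lambda$ and the set of penalized kernels $W_{l}$ in the marked layers are my own notation, not from the paper):

$$ L_{\text{total}} = L_{\text{rec}} + L_{\text{KL}} + \lambda \sum_{l} \lVert W_{l} \rVert_{2}^{2} $$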

Best Answer

The loss term underlined in red is the reconstruction loss between the input and the reconstruction of the input (the paper is about reconstruction!), not L2 regularization.

A VAE's loss has two components: the reconstruction loss (since the autoencoder's aim is to learn to reconstruct) and the KL loss (which measures how much information is lost, or how far we have diverged from the prior). The actual form of the VAE objective (which we aim to maximize) is:

$$ L(\theta, \phi) = \sum_{i=1}^{N} E_{z_i \sim q_{\phi}(z|x_i)} \left[ \log p_{\theta}(x_i|z) \right] - KL\left( q_{\phi}(z|x_i) \,\|\, p(z) \right) $$ where $(x, z)$ is an input and latent vector pair. The encoder and decoder networks are $q$ and $p$, respectively. Since the decoder is taken to be Gaussian, the reconstruction loss becomes the squared difference (L2 distance) between input and reconstruction (the logarithm of a Gaussian reduces to a squared difference).
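
For illustration, here is a minimal sketch of these two terms in Keras backend code, assuming the diagonal-Gaussian posterior given by the z_mean / z_log_var layers in the question and a Gaussian decoder (the function and variable names are placeholders, not the asker's actual code):

from tensorflow.keras import backend as K

def vae_loss(x, x_reconstructed, z_mean, z_log_var):
    # Reconstruction term: squared difference between input and reconstruction,
    # summed over the spatial and channel dimensions of each image
    rec_loss = K.sum(K.square(x - x_reconstructed), axis=[1, 2, 3])
    # KL term: closed form of KL(N(mu, sigma^2) || N(0, I)) for a diagonal Gaussian
    kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var),
                           axis=[1, 2, 3])
    # Negative ELBO averaged over the batch; this is what gets minimized
    return K.mean(rec_loss + kl_loss)

Minimizing this quantity is the same as maximizing the objective above.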

To get a better understanding of the VAE, let's try to derive its loss. Our aim is to infer good latents from the observed data. However, there is a vital problem: for a given input we never observe the corresponding latent, and the posterior over latents cannot be computed directly. To see why, consider Bayes' theorem:

$$ p(z|x) = \frac{p(x|z)p(z)}{p(x)} = \frac{p(x|z)p(z)}{\int p(x|z)p(z)dz} $$

The integral in the denominator is intractable, so we have to use approximate Bayesian inference methods. The tool we're using is mean-field variational Bayes, where you approximate the full posterior with a member of a simpler family of posteriors. Say our approximation is $q_{\phi}(z|x)$. Our aim now becomes making this approximation as good as possible, which can be measured via the KL divergence:

\begin{align} q^{*}_{\phi}(z|x) &= \operatorname*{argmin}_{\phi} KL\left( q_{\phi}(z|x) \,\|\, p(z|x) \right) \\ &= \operatorname*{argmin}_{\phi} \left( E_{q} \left[ \log q_{\phi}(z|x) \right] - E_{q} \left[ \log p(z, x) \right] + \log p(x) \right) \end{align}
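
Spelling out the second line: expand the KL divergence and substitute $p(z|x) = p(z, x) / p(x)$,

\begin{align} KL\left( q_{\phi}(z|x) \,\|\, p(z|x) \right) &= E_{q} \left[ \log q_{\phi}(z|x) - \log p(z|x) \right] \\ &= E_{q} \left[ \log q_{\phi}(z|x) \right] - E_{q} \left[ \log p(z, x) \right] + \log p(x) \end{align}

where $\log p(x)$ comes out of the expectation because it does not depend on $z$.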

Again, due to $p(x)$, we cannot optimize this KL divergence directly. So, leave that term alone and rearrange:

$$ \log p(x) = KL\left( q_{\phi}(z|x) \,\|\, p(z|x) \right) - \left( E_{q} \left[ \log q_{\phi}(z|x) \right] - E_{q} \left[ \log p(z, x) \right] \right) $$

We try to minimize the KL divergence, and this divergence is non-negative. Also, $\log p(x)$ is constant with respect to $\phi$. So, minimizing the KL divergence is equivalent to maximizing the other term, which is called the evidence lower bound (ELBO). Let's rewrite the ELBO then:

\begin{align} ELBO(\phi) &= E_{q} \left[ \log p(z, x) \right] - E_{q} \left[ \log q_{\phi}(z|x) \right] \\ &= E_{q} \left[ \log p(x|z) \right] + E_{q} \left[ \log p(z) \right] - E_{q} \left[ \log q_{\phi}(z|x) \right] \\ &= E_{q} \left[ \log p(x|z) \right] - KL\left( q_{\phi}(z|x) \,\|\, p(z) \right) \end{align}
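
For the Gaussian case (diagonal posterior $q_{\phi}(z|x) = \mathcal{N}(\mu, \sigma^{2} I)$ and standard normal prior $p(z) = \mathcal{N}(0, I)$, which is what the z_mean and z_log_var layers in the question parameterize), this KL term has the familiar closed form

$$ KL\left( \mathcal{N}(\mu, \sigma^{2} I) \,\|\, \mathcal{N}(0, I) \right) = \frac{1}{2} \sum_{j} \left( \mu_{j}^{2} + \sigma_{j}^{2} - \log \sigma_{j}^{2} - 1 \right) $$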

Then, you have to maximize the ELBO for each datapoint.

L2 regularization (or weight decay) is different from the reconstruction loss: it is used to control the network's weights. Of course, you can try L2 regularization if you think your network is overfitting. Hope this helps!