Auto-Encoding Variational Bayes & Application to Biomedical Segmentation

Journal Club Notes by Jianxiao

This note covers the general derivations [1] and the derivations specific to the segmentation application [2].

1. Generic Auto-Encoding Variational Bayes (AEVB)

1.1. Problem scenario

Figure 1. Directed graphical model considered

Consider a model consisting of a dataset $X = \{ x^{(i)} \}_{i=1}^N$ and its associated latent variables $z$, as shown in Figure 1. Values of $z$ are generated from a prior distribution $p_\theta(z)$; values of $x$ are generated from the conditional distribution $p_\theta(x|z)$.

Inference based on this model can be intractable. Specifically, the marginal likelihood $p_\theta(x) = \int p_\theta(z)p_\theta(x|z) \mathop{dz}$ is intractable as it requires integrating over the latent variable $z$. As a result, the posterior density $p_\theta(z|x) = \frac{p_\theta(x|z)p_\theta(z)}{p_\theta(x)}$ is also intractable, since it requires evaluating $p_\theta(x)$ first.
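To make the integral concrete, here is a minimal sketch on a toy model that is an assumption of this example, not part of the papers: $p(z) = \mathcal{N}(0, 1)$ and $p(x|z) = \mathcal{N}(z, 1)$. In this one-dimensional Gaussian case the marginal is actually known, $p(x) = \mathcal{N}(0, 2)$, so a naive Monte Carlo estimate over the prior can be checked against it. With high-dimensional $z$ and a neural-network likelihood, this kind of naive sampling generally becomes impractical.

```python
import math
import random

def normal_pdf(v, mean, std):
    """Density of a univariate normal N(mean, std^2) evaluated at v."""
    return math.exp(-0.5 * ((v - mean) / std) ** 2) / (std * math.sqrt(2.0 * math.pi))

def marginal_likelihood_mc(x, n_samples=20000, seed=0):
    """Estimate p(x) = E_{p(z)}[p(x|z)] by sampling z from the toy prior N(0, 1)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(0.0, 1.0)         # z ~ p(z)
        total += normal_pdf(x, z, 1.0)  # p(x|z) = N(x; z, 1)
    return total / n_samples

def marginal_likelihood_exact(x):
    """For this toy model the integral has a known value: p(x) = N(x; 0, 2)."""
    return normal_pdf(x, 0.0, math.sqrt(2.0))

estimate = marginal_likelihood_mc(1.0)
exact = marginal_likelihood_exact(1.0)
```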

The Auto-Encoding Variational Bayes (AEVB) approach is proposed for approximate inference of the marginal likelihood $p_\theta(x)$ and the posterior density $p_\theta(z|x)$; these approximations in turn help to estimate the parameters $\theta$ using ML or MAP.

1.2. Variational lower bound

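We start by introducing an approximation $q_\phi(z|x)$ to the intractable true posterior $p_\theta(z|x)$. We can consider the KL divergence between these two distributions:

$$KL(q_\phi(z|x)||p_\theta(z|x)) = \mathbb{E}_{q_\phi(z|x)}[\log q_\phi(z|x) - \log p_\theta(z|x)] = \mathbb{E}_{q_\phi(z|x)}[\log q_\phi(z|x) - \log p_\theta(x,z)] + \log p_\theta(x)$$

Rearranging the terms gives:

$$\log p_\theta(x) = KL(q_\phi(z|x)||p_\theta(z|x)) + \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x,z) - \log q_\phi(z|x)]$$

Since the first term is the non-negative KL divergence, the second term, denoted $\mathcal{L}(\theta,\phi;x)$, is a lower bound on the log marginal likelihood $\log p_\theta(x)$, also called the variational lower bound. The variational lower bound can be written in two forms: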

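1) The joint-contrastive form

$$\mathcal{L}(\theta,\phi;x) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x,z) - \log q_\phi(z|x)]$$

2) The prior-contrastive form

$$\mathcal{L}(\theta,\phi;x) = -KL(q_\phi(z|x)||p_\theta(z)) + \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)]$$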

The prior-contrastive form can be related to auto-encoders: the KL term acts as a regularisation term, while the expectation term is the expected negative reconstruction error ($z$ is sampled given $x$ via $g(\epsilon,x)$; $x$ is then reconstructed from $z$ via $p_\theta(x|z)$).
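As a sanity check of the bound, the sketch below evaluates the prior-contrastive form analytically on a toy model chosen for this example only (an assumption, not from the papers): $p(z) = \mathcal{N}(0,1)$, $p(x|z) = \mathcal{N}(z,1)$ and $q(z|x) = \mathcal{N}(m, s^2)$. For this model $\log p(x)$ and the true posterior $\mathcal{N}(x/2, 1/2)$ are known, so the gap between $\log p(x)$ and the bound can be computed directly.

```python
import math

def log_marginal(x):
    """Exact log p(x) for the toy model, where p(x) = N(x; 0, 2)."""
    return -0.5 * math.log(2.0 * math.pi * 2.0) - x * x / 4.0

def elbo(x, m, s):
    """Prior-contrastive lower bound for q(z|x) = N(m, s^2)."""
    # E_q[log p(x|z)] with p(x|z) = N(x; z, 1)
    expected_log_lik = -0.5 * math.log(2.0 * math.pi) - 0.5 * ((x - m) ** 2 + s * s)
    # KL(N(m, s^2) || N(0, 1))
    kl = 0.5 * (m * m + s * s - 1.0 - math.log(s * s))
    return expected_log_lik - kl

x = 1.0
# q equal to the true posterior N(x/2, 1/2): the bound should be tight.
gap_true_posterior = log_marginal(x) - elbo(x, x / 2.0, math.sqrt(0.5))
# q equal to the prior N(0, 1): the bound should be strictly below log p(x).
gap_prior_as_q = log_marginal(x) - elbo(x, 0.0, 1.0)
```

The gap is the KL divergence between $q$ and the true posterior, so it vanishes only when the two coincide.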

1.3. Stochastic Gradient Variational Bayes (SGVB) estimator

In order to optimise $\mathcal{L}(\theta,\phi;x)$, we need to differentiate it w.r.t. $\theta$ and $\phi$. However, differentiating w.r.t. $\phi$ is problematic, as the usual Monte Carlo gradient estimator exhibits high variance [3]. The usual Monte Carlo method estimates the expectation terms by sampling $z$ from $q_\phi(z|x)$:

$$\mathbb{E}_{q_\phi(z|x)}[f(z)] \approx \frac{1}{L}\sum_{l=1}^{L} f(z^{(l)}), \quad z^{(l)} \sim q_\phi(z|x)$$

Instead, we use a reparametrisation trick so that the Monte Carlo estimate becomes differentiable w.r.t. $\phi$. We reparametrise $z$ as:

$$z = g_\phi(\epsilon, x), \quad \epsilon \sim p(\epsilon)$$

where the transformation $g_\phi(\cdot)$ and auxiliary noise $\epsilon$ are chosen based on the form of $q_\phi(z|x)$.
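A minimal sketch of the trick for a univariate Gaussian, assuming $g(\epsilon) = \mu + \sigma\epsilon$ with $\epsilon \sim \mathcal{N}(0,1)$ and an illustrative test function $f(z) = z^2$ (both chosen for this example). Because $z$ is a deterministic function of $\mu$, $\sigma$ and the noise, the chain rule gives per-sample gradients that can simply be averaged.

```python
import random

def reparam_gradients(mu, sigma, n_samples=20000, seed=0):
    """Monte Carlo gradients of E[z^2] w.r.t. mu and sigma, with z = mu + sigma * eps."""
    rng = random.Random(seed)
    grad_mu = 0.0
    grad_sigma = 0.0
    for _ in range(n_samples):
        eps = rng.gauss(0.0, 1.0)  # auxiliary noise, eps ~ N(0, 1)
        z = mu + sigma * eps       # z = g(eps): a sample from N(mu, sigma^2)
        df_dz = 2.0 * z            # derivative of the test function f(z) = z^2
        grad_mu += df_dz * 1.0     # chain rule: dz/dmu = 1
        grad_sigma += df_dz * eps  # chain rule: dz/dsigma = eps
    return grad_mu / n_samples, grad_sigma / n_samples

grad_mu, grad_sigma = reparam_gradients(mu=1.0, sigma=0.5)
# Analytically E[z^2] = mu^2 + sigma^2, so the exact gradients are 2*mu and 2*sigma.
```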

Figure 2. How the reparametrisation trick changes the network structure

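Essentially, sampling $z$ from $q_\phi(z|x)$ is replaced by sampling $\epsilon$ from $p(\epsilon)$ and applying the deterministic transformation $g_\phi$ (Figure 2). As a result, we can rewrite integrals involving $q_\phi(z|x)$ in forms involving $\epsilon$ instead:

$$\int q_\phi(z|x) f(z) \mathop{dz} = \int p(\epsilon) f(g_\phi(\epsilon,x)) \mathop{d\epsilon} \approx \frac{1}{L}\sum_{l=1}^{L} f(g_\phi(\epsilon^{(l)},x)), \quad \epsilon^{(l)} \sim p(\epsilon)$$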

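We can then compute the SGVB estimator for both forms of the variational lower bound, with $z^{(l)} = g_\phi(\epsilon^{(l)},x)$ and $\epsilon^{(l)} \sim p(\epsilon)$:

1) The joint-contrastive form

$$\tilde{\mathcal{L}}^A(\theta,\phi;x) = \frac{1}{L}\sum_{l=1}^{L} \left[\log p_\theta(x,z^{(l)}) - \log q_\phi(z^{(l)}|x)\right]$$

2) The prior-contrastive form

$$\tilde{\mathcal{L}}^B(\theta,\phi;x) = -KL(q_\phi(z|x)||p_\theta(z)) + \frac{1}{L}\sum_{l=1}^{L} \log p_\theta(x|z^{(l)})$$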

The prior-contrastive form typically has less variance, but it relies on the KL term having a closed form. In particular, if both the prior and the approximate posterior are Gaussian, the KL term can be computed analytically.
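The two estimators can be sketched on a toy model that is an assumption of this example: $p(z) = \mathcal{N}(0,1)$, $p(x|z) = \mathcal{N}(z,1)$ and $q(z|x) = \mathcal{N}(m, s^2)$. Both are computed from the same reparametrised samples and compared with the analytic lower bound, which is available for this model. The sketch only checks that both estimators target the same quantity; it makes no claim about their relative variance.

```python
import math
import random

LOG_2PI = math.log(2.0 * math.pi)

def log_normal(v, mean, std):
    """Log density of a univariate normal N(mean, std^2) at v."""
    return -0.5 * LOG_2PI - math.log(std) - 0.5 * ((v - mean) / std) ** 2

def kl_to_standard_normal(m, s):
    """Closed-form KL(N(m, s^2) || N(0, 1))."""
    return 0.5 * (m * m + s * s - 1.0 - math.log(s * s))

def sgvb_estimates(x, m, s, n_samples=20000, seed=0):
    """Return (joint-contrastive, prior-contrastive) estimates of the lower bound."""
    rng = random.Random(seed)
    joint_sum = 0.0
    recon_sum = 0.0
    for _ in range(n_samples):
        eps = rng.gauss(0.0, 1.0)
        z = m + s * eps                   # reparametrised sample from q(z|x)
        log_lik = log_normal(x, z, 1.0)   # log p(x|z)
        # log p(x, z) - log q(z|x)
        joint_sum += log_lik + log_normal(z, 0.0, 1.0) - log_normal(z, m, s)
        recon_sum += log_lik
    joint = joint_sum / n_samples
    prior = -kl_to_standard_normal(m, s) + recon_sum / n_samples
    return joint, prior

def elbo_exact(x, m, s):
    """Analytic lower bound for the toy model."""
    expected_log_lik = -0.5 * LOG_2PI - 0.5 * ((x - m) ** 2 + s * s)
    return expected_log_lik - kl_to_standard_normal(m, s)

joint_est, prior_est = sgvb_estimates(x=1.0, m=0.2, s=0.8)
exact = elbo_exact(1.0, 0.2, 0.8)
```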

2. Application to biomedical segmentation

2.1. Problem scenario

CNN-based biomedical segmentation usually requires a large set of annotated data and does not take anatomical knowledge into account. In the Dalca et al. paper [2], a two-step AEVB approach is proposed to carry out unsupervised segmentation learning with an anatomical prior incorporated.

Consider a model consisting of image data $x$, its segmentation map $s$, and the latent variable $z$, which represents the underlying shape embedding (Figure 3).

Figure 3. Generative model for the segmentation problem

The model generates values based on the following prior and likelihood densities:

1) the prior probability of the latent variable $z$ is modeled by a standard normal distribution

$$p(z) = \mathcal{N}(z; 0, \mathrm{I})$$

2) the labels $s$ are drawn from a categorical distribution given $z$

$$p_\theta(s|z) = \prod_j f_{j,s[j]}(z; \theta_{s|z})$$

where $f_{j,m}(\cdot; \theta_{s|z})$ is the probability of label $m$ at voxel $j$

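3) the voxel intensity values $x$ are drawn from normal distributions with a diagonal covariance matrix

$$p_\theta(x|s) = \prod_j \prod_m \mathcal{N}(x[j]; \mu_m, \sigma_m^2)^{\delta(s[j]=m)}$$

where $\mu_m$ and $\sigma_m^2$ are the intensity mean and variance of label $m$, and $\delta(s[j]=m) = 1$ if $s[j]=m$ and $0$ otherwise.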
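The three generative steps above can be sketched as ancestral sampling. Everything concrete here is hypothetical and chosen only for illustration: `toy_decoder` (a per-voxel softmax over linear functions of $z$) stands in for the learnt decoder $f_{j,m}$, and the sizes, weights and per-label intensity statistics are made up.

```python
import math
import random

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def toy_decoder(z, weights):
    """Hypothetical stand-in for f_{j,m}(z): label probabilities for every voxel j."""
    return [softmax([sum(w * zk for w, zk in zip(w_m, z)) for w_m in voxel_w])
            for voxel_w in weights]

def sample_image(weights, label_means, label_stds, latent_dim, rng):
    """Draw z ~ N(0, I), then labels s given z, then intensities x given s."""
    z = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
    probs = toy_decoder(z, weights)
    # One categorical draw per voxel.
    labels = [rng.choices(range(len(p)), weights=p)[0] for p in probs]
    # One Gaussian intensity per voxel, using the statistics of its label.
    image = [rng.gauss(label_means[m], label_stds[m]) for m in labels]
    return z, labels, image

rng = random.Random(0)
n_voxels, n_labels, latent_dim = 4, 3, 2
weights = [[[rng.uniform(-1.0, 1.0) for _ in range(latent_dim)]
            for _ in range(n_labels)] for _ in range(n_voxels)]
z, labels, image = sample_image(weights, [0.0, 50.0, 100.0], [1.0, 5.0, 5.0],
                                latent_dim, rng)
```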

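We aim to learn segmentation from an image using maximum a posteriori (MAP) estimation:

$$\hat{s} = \arg\max_s p_\theta(s|x) = \arg\max_s p_\theta(s,x)$$

However, $p_\theta(s,x)$ is intractable as it requires integrating over $z$:

$$p_\theta(s,x) = \int p_\theta(x|s)\, p_\theta(s|z)\, p(z) \mathop{dz}$$

Similarly, $p_\theta(z|x,s)$ is also intractable, which means the EM algorithm does not work in this case.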

Instead, we will optimise $\log p_\theta(x)$ and $\log p_\theta(s)$ by optimising their respective variational lower bounds. Note that we cannot optimise the two jointly because we do not have paired sets of $x$ and $s$ in the unsupervised learning setting. Nevertheless, we want to learn mappings between the latent space and both the image space and the label space, so that we can apply the anatomical knowledge learnt from the label-latent mapping to new images.

2.2. SGVB estimator derivations

2.2.1. Learning anatomical prior

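Using the AEVB framework, we approximate the true posterior $p_\theta(z|s)$ with $q_\phi(z|s)$, which is modeled as a normal distribution with diagonal covariance matrix:

$$q_\phi(z|s) = \mathcal{N}(z; \mu_{z|s}, \mathrm{diag}(\sigma_{z|s}^2))$$

The respective variational lower bound is:

$$\mathcal{L}(\theta,\phi;s) = -KL(q_\phi(z|s)||p(z)) + \mathbb{E}_{q_\phi(z|s)}[\log p_\theta(s|z)]$$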

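Using the reparametrisation trick, we can estimate the expectation term as:

$$\mathbb{E}_{q_\phi(z|s)}[\log p_\theta(s|z)] \approx \frac{1}{L}\sum_{l=1}^{L} \log p_\theta(s|z^{(l)}), \quad z^{(l)} = \mu_{z|s} + \sigma_{z|s} \odot \epsilon^{(l)}, \quad \epsilon^{(l)} \sim \mathcal{N}(0, \mathrm{I})$$

Since both the prior $p(z)$ and the approximate posterior $q_\phi(z|s)$ are Gaussian, the KL term has a closed form (see Appendix A for the proof):

$$KL(q_\phi(z|s)||p(z)) = \frac{1}{2}\sum_{k}\left(\mu_{z|s,k}^2 + \sigma_{z|s,k}^2 - \log \sigma_{z|s,k}^2 - 1\right)$$

where the sum runs over the dimensions of $z$.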

2.2.2. Unsupervised learning

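Similar to the previous section, we approximate the true posterior $p_\theta(z|x)$ with $q_\phi(z|x)$, which is modeled as a normal distribution with diagonal covariance matrix:

$$q_\phi(z|x) = \mathcal{N}(z; \mu_{z|x}, \mathrm{diag}(\sigma_{z|x}^2))$$

The respective variational lower bound is:

$$\mathcal{L}(\theta,\phi;x) = -KL(q_\phi(z|x)||p(z)) + \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)]$$

Again, the KL term has a closed-form solution:

$$KL(q_\phi(z|x)||p(z)) = \frac{1}{2}\sum_{k}\left(\mu_{z|x,k}^2 + \sigma_{z|x,k}^2 - \log \sigma_{z|x,k}^2 - 1\right)$$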

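The expectation term, on the other hand, is intractable, as computing $p_\theta(x|z)$ requires marginalising over all possible segmentation maps $s$. Instead, we will use a lower bound of the term derived from Jensen’s inequality (more details in Appendix B):

$$\mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] = \mathbb{E}_{q_\phi(z|x)}\left[\log \mathbb{E}_{p_\theta(s|z)}[p_\theta(x|s)]\right] \geq \mathbb{E}_{q_\phi(z|x)}\left[\mathbb{E}_{p_\theta(s|z)}[\log p_\theta(x|s)]\right]$$

We can then estimate this lower bound, sampling $z$ using the reparametrisation trick:

$$\mathbb{E}_{q_\phi(z|x)}\left[\mathbb{E}_{p_\theta(s|z)}[\log p_\theta(x|s)]\right] \approx \frac{1}{L}\sum_{l=1}^{L} \mathbb{E}_{p_\theta(s|z^{(l)})}[\log p_\theta(x|s)], \quad z^{(l)} = \mu_{z|x} + \sigma_{z|x} \odot \epsilon^{(l)}, \quad \epsilon^{(l)} \sim \mathcal{N}(0, \mathrm{I})$$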

2.3. Network architecture

Two CNNs are used, one for learning the anatomical prior and one for unsupervised learning of segmentation from image.

Figure 4 shows the auto-encoder architecture for learning the anatomical prior. This network learns the mapping between the latent space and the label space, using only segmentation maps.

Figure 4. Auto-encoder to learn the anatomical prior

Given an input segmentation map $s$, the encoder outputs the parameters of the posterior $q_\phi(z|s)$, namely $\mu_{z|s}$ and $\sigma_{z|s}$. The latent variable $z$ is then sampled using the reparametrisation trick, where $z = g(\epsilon,s) = \mu_{z|s} + \sigma_{z|s} \epsilon$ with $\epsilon \sim \mathcal{N}(0, \mathrm{I})$. The transformation $g(\cdot)$ produces a normally distributed sample by adding noise, scaled by the standard deviation, to the mean. Finally, $s$ is reconstructed by the decoder with $p_\theta(s|z)$, represented by the top output layer in Figure 4.

An additional prior is added in the actual implementation, represented by the bottom output layer in Figure 4. This location prior $p_{loc}(s)$ is computed as the voxel-wise frequency of labels in the prior training dataset. The output segmentation is obtained by multiplying the prior $p_{loc}(s)$ with the reconstructed segmentation map from $p_\theta(s|z)$ (in the actual implementation, by adding their logarithms).
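A minimal sketch of combining the two in log space for a handful of voxels. The function name, the small `eps` guard against `log(0)` and the per-voxel renormalisation are assumptions of this example; the source only states that the logarithms are added.

```python
import math

def combine_with_location_prior(decoder_probs, loc_prior, eps=1e-12):
    """Add log-probabilities per voxel and label, then renormalise per voxel.

    decoder_probs[j][m]: probability of label m at voxel j from p(s|z).
    loc_prior[j][m]: voxel-wise label frequency p_loc for the same voxel and label.
    """
    combined = []
    for p_dec, p_loc in zip(decoder_probs, loc_prior):
        # Adding logs is the same as multiplying the probabilities.
        log_scores = [math.log(a + eps) + math.log(b + eps)
                      for a, b in zip(p_dec, p_loc)]
        # Softmax over labels turns the scores back into probabilities.
        top = max(log_scores)
        exps = [math.exp(v - top) for v in log_scores]
        total = sum(exps)
        combined.append([e / total for e in exps])
    return combined

# Two voxels, two labels: an undecided decoder is pulled towards the prior.
decoder_probs = [[0.5, 0.5], [0.2, 0.8]]
loc_prior = [[0.9, 0.1], [0.5, 0.5]]
output = combine_with_location_prior(decoder_probs, loc_prior)
```

In the first voxel the decoder is undecided, so the output follows the location prior; in the second the prior is flat, so the output follows the decoder.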

The parameters of $q_\phi(z|s)$ and $p_\theta(s|z)$, as well as the output segmentation labels $s[j]$, are then used to estimate and optimise the SGVB lower bound.

Figure 5 shows the architecture for unsupervised learning of segmentation from input images. The decoder part of this network is copied from the prior-learning network from Figure 4, which helps to incorporate the prior anatomical knowledge learnt.

Figure 5. Architecture for unsupervised learning

The image encoder is first trained in a similar fashion to the prior auto-encoder; the pre-trained weights are used for initialisation during actual training.

Given an input scan $x$, the image encoder outputs the parameters of the posterior $q_\phi(z|x)$, namely $\mu_{z|x}$ and $\sigma_{z|x}$. The latent variable $z$ is then sampled using the reparametrisation trick, where $z = g(\epsilon,x) = \mu_{z|x} + \sigma_{z|x} \epsilon$ with $\epsilon \sim \mathcal{N}(0, \mathrm{I})$. Then, the already trained prior decoder outputs $s$, which is used to generate the reconstructed $x$ with $p_\theta(x|s)$.

The parameters of $q_\phi(z|x)$ and $p_\theta(x|s)$, as well as the reconstructed image $x$, are then used to estimate and optimise the SGVB lower bound.

2.4. Results

The networks are tested on two datasets (Figure 6):

1) T1w scan dataset: more than 14,000 T1-weighted MRI scans from ADNI, ABIDE, GSP, etc., all resampled to 256x256x256 (1mm isotropic) and cropped to 160x192x224 to remove entirely-background voxels.

2) T2-FLAIR scan dataset: 3800 T2-FLAIR scans from ADNI (5mm slice spacing), linearly registered to T1w images (using ANTs). No ground truth.

Figure 6. Example T1w and T2-FLAIR images

A prior training set is composed using 5000 T1w images with ground truth segmentation maps generated by FreeSurfer (with manual correction and QC). This set is used to train the prior auto-encoder. The rest of the T1w images are split into training, validation and test sets for unsupervised learning. In the T2 scan case, subjects common to both datasets are excluded from the prior training set.

Figure 7 shows segmentation results from 3 subjects in the T1w scan dataset, in comparison to their respective ground truth. Figure 8 shows one subject’s segmentation results from the T2-FLAIR scan dataset. The estimated segmentation boundaries are generally aligned with the ground truth boundaries, but are much smoother and lack fine details.

Figure 7. Example segmentation results for T1w scan dataset

Figure 8. Example segmentation results for T2-FLAIR scan dataset

Appendix

A. Closed-form solution for KL divergence

When two distributions are both Gaussian, the KL divergence between them has a closed form. I will start with the proof for the generic case, and then show the solution for a special case.

Consider two multivariate normal distributions of dimension $k$, $p(x) \sim \mathcal{N} (\mu_1, \Sigma_1)$ and $q(x) \sim \mathcal{N} (\mu_2, \Sigma_2)$. There is no constraint on the form of the covariance matrices. Since the KL divergence is defined as $KL(p(x)||q(x)) = \mathbb{E}_p [\log p(x) - \log q(x)]$ for continuous distributions, we start by expressing the log density of each distribution.

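For $p(x)$:

$$\log p(x) = -\frac{k}{2}\log(2\pi) - \frac{1}{2}\log|\Sigma_1| - \frac{1}{2}(x-\mu_1)^T \Sigma_1^{-1} (x-\mu_1)$$

where $|\cdot|$ is the matrix determinant.

Similarly for $q(x)$:

$$\log q(x) = -\frac{k}{2}\log(2\pi) - \frac{1}{2}\log|\Sigma_2| - \frac{1}{2}(x-\mu_2)^T \Sigma_2^{-1} (x-\mu_2)$$

Then we have:

$$\begin{aligned}
KL(p(x)||q(x)) &= \mathbb{E}_p[\log p(x) - \log q(x)] \\
&= \frac{1}{2}\log\frac{|\Sigma_2|}{|\Sigma_1|} - \frac{1}{2}\mathbb{E}_p[(x-\mu_1)^T \Sigma_1^{-1} (x-\mu_1)] + \frac{1}{2}\mathbb{E}_p[(x-\mu_2)^T \Sigma_2^{-1} (x-\mu_2)] \\
&= \frac{1}{2}\log\frac{|\Sigma_2|}{|\Sigma_1|} - \frac{1}{2}\mathrm{Tr}(\Sigma_1^{-1}\Sigma_1) + \frac{1}{2}\left[\mathrm{Tr}(\Sigma_2^{-1}\Sigma_1) + (\mu_1-\mu_2)^T \Sigma_2^{-1} (\mu_1-\mu_2)\right] \\
&= \frac{1}{2}\left[\log\frac{|\Sigma_2|}{|\Sigma_1|} - k + \mathrm{Tr}(\Sigma_2^{-1}\Sigma_1) + (\mu_1-\mu_2)^T \Sigma_2^{-1} (\mu_1-\mu_2)\right]
\end{aligned}$$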

Some tricks used are:

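1) trace trick: $\mathbb{E}[x] = \mathbb{E}[\mathrm{Tr}(x)]$ if $x$ is scalar. Applied to the quadratic forms, together with the cyclic property of the trace, this gives $\mathbb{E}_p[(x-\mu_1)^T \Sigma^{-1} (x-\mu_1)] = \mathrm{Tr}(\Sigma^{-1}\, \mathbb{E}_p[(x-\mu_1)(x-\mu_1)^T]) = \mathrm{Tr}(\Sigma^{-1}\Sigma_1)$ for any fixed matrix $\Sigma^{-1}$.

2) $\mathbb{E}_p [ (x-\mu_1)^T \Sigma_2^{-1} (\mu_1-\mu_2)]$: this is equal to 0. This cross term appears when writing $x-\mu_2 = (x-\mu_1) + (\mu_1-\mu_2)$; it is linear in $(x-\mu_1)$, and $\mathbb{E}_p[x-\mu_1] = 0$.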

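Now, consider a special case, where $p(x) \sim \mathcal{N} (\mu_1, \mathrm{diag}(\sigma_1^2))$ and $q(x) \sim \mathcal{N} (0, \mathrm{I})$. In this case, we will be able to simplify the solution even further:

$$KL(p(x)||q(x)) = \frac{1}{2}\sum_{i=1}^{k}\left(\mu_{1,i}^2 + \sigma_{1,i}^2 - \log \sigma_{1,i}^2 - 1\right)$$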

This is the solution used in section 2.2.
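The special case can be checked numerically. The sketch below implements the closed form for a diagonal Gaussian against a standard normal and compares it with a Monte Carlo estimate of $\mathbb{E}_p[\log p(x) - \log q(x)]$; the function names and the example parameters are chosen for this illustration.

```python
import math
import random

def kl_diag_gaussian_to_standard(mu, sigma):
    """Closed-form KL(N(mu, diag(sigma^2)) || N(0, I))."""
    return 0.5 * sum(m * m + s * s - math.log(s * s) - 1.0
                     for m, s in zip(mu, sigma))

def kl_monte_carlo(mu, sigma, n_samples=20000, seed=0):
    """Estimate E_p[log p(x) - log q(x)] by sampling x from p."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        for m, s in zip(mu, sigma):
            eps = rng.gauss(0.0, 1.0)
            x = m + s * eps
            log_p = -math.log(s) - 0.5 * eps * eps  # log(2*pi) terms cancel
            log_q = -0.5 * x * x
            total += log_p - log_q
    return total / n_samples

mu, sigma = [0.5, -1.0], [0.8, 1.2]
closed_form = kl_diag_gaussian_to_standard(mu, sigma)
estimate = kl_monte_carlo(mu, sigma)
```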

B. Jensen’s inequality

Jensen’s inequality states that the secant line of a convex function lies above the graph of the function. The concept is intuitive.


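Mathematically, for any convex function $f$, Jensen’s inequality is:

$$f(t x_1 + (1-t) x_2) \leq t f(x_1) + (1-t) f(x_2), \quad t \in [0,1]$$

Conversely, for any concave function $g$, the sign is simply flipped:

$$g(t x_1 + (1-t) x_2) \geq t g(x_1) + (1-t) g(x_2), \quad t \in [0,1]$$

When probability density functions are involved (e.g. when computing expectation), Jensen’s inequality is:

$$f\left(\int g(x) p(x) \mathop{dx}\right) \leq \int f(g(x)) p(x) \mathop{dx}$$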

where $f$ is a convex function, $g$ is any real-valued measurable function, and $p(x)$ is a probability density function (i.e. $\int p(x) \mathop{dx} = 1$).
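For the concave $\log$, the inequality reads $\log \mathbb{E}[v] \geq \mathbb{E}[\log v]$. A small numerical check on a discrete distribution (the values and probabilities are arbitrary examples):

```python
import math

def jensen_gap_log(values, probs):
    """Return log(E[v]) - E[log v] for positive values under a discrete distribution."""
    mean = sum(p * v for v, p in zip(values, probs))
    mean_log = sum(p * math.log(v) for v, p in zip(values, probs))
    return math.log(mean) - mean_log

# Spread-out values give a strictly positive gap.
gap_spread = jensen_gap_log([1.0, 4.0], [0.5, 0.5])
# A constant random variable gives equality.
gap_constant = jensen_gap_log([3.0, 3.0], [0.5, 0.5])
```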

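In section 2.2.2, $\log$ is a concave function, and $p_\theta(s|z)$ is the probability distribution. Therefore:

$$\log p_\theta(x|z) = \log \mathbb{E}_{p_\theta(s|z)}[p_\theta(x|s)] \geq \mathbb{E}_{p_\theta(s|z)}[\log p_\theta(x|s)]$$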

  1. Diederik P Kingma and Max Welling. “Auto-Encoding Variational Bayes”. In: ArXiv e-prints (2013). arXiv: 1312.6114 

  2. Adrian V Dalca, John Guttag, and Mert R Sabuncu. “Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation”. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018, pp. 9290–9299.

  3. David M Blei, Michael I Jordan, and John W Paisley. “Variational Bayesian Inference with Stochastic Search”. In: Proceedings of the 29th International Conference on Machine Learning (ICML-12). 2012, pp. 1367–1374.