Auto-Encoding Variational Bayes & Application to Biomedical Segmentation
Journal Club Notes by Jianxiao
This note includes general derivations^{1} and specific derivations for the segmentation application^{2}.
1. Generic Auto-Encoding Variational Bayes (AEVB)
1.1. Problem scenario
Figure 1. Directed graphical model considered
Consider a model consisting of a dataset $X = \{ x^{(i)} \}^N_{i=1}$ and its associated latent variables $z$, shown in Figure 1. Values of variable $z$ are generated from a prior distribution $p_\theta(z)$; values of variable $x$ are generated from the conditional distribution $p_\theta(x|z)$.
Inference based on this model can be intractable. Specifically, the marginal likelihood $p_\theta(x) = \int p_\theta(z)p_\theta(x|z) \mathop{dz}$ is intractable, as it requires integrating over the latent variable $z$. As a result, the posterior density $p_\theta(z|x) = \frac{p_\theta(x|z)p_\theta(z)}{p_\theta(x)}$ is also intractable, as it requires evaluating $p_\theta(x)$ first.
The Auto-Encoding Variational Bayes (AEVB) approach is proposed for inference of the marginal likelihood $p_\theta(x)$ and the posterior density $p_\theta(z|x)$; their approximations also help to estimate the parameters $\theta$ using ML or MAP.
1.2. Variational lower bound
We start by introducing an approximation $q_\phi(z|x)$ to the intractable true posterior $p_\theta(z|x)$. We can consider the KL divergence between these two distributions:
Rearranging the terms gives:
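Reconstructing the standard derivation from Kingma & Welling^{1}, the divergence decomposes as:

$$
\mathrm{KL}\big(q_\phi(z|x)\,\|\,p_\theta(z|x)\big) = \mathbb{E}_{q_\phi(z|x)}\big[\log q_\phi(z|x) - \log p_\theta(z|x)\big] = \mathbb{E}_{q_\phi(z|x)}\big[\log q_\phi(z|x) - \log p_\theta(x,z)\big] + \log p_\theta(x),
$$

and rearranging:

$$
\log p_\theta(x) = \mathrm{KL}\big(q_\phi(z|x)\,\|\,p_\theta(z|x)\big) + \mathcal{L}(\theta,\phi;x), \qquad \mathcal{L}(\theta,\phi;x) = \mathbb{E}_{q_\phi(z|x)}\big[\log p_\theta(x,z) - \log q_\phi(z|x)\big].
$$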
Since the first term is the non-negative KL divergence, the second term is a lower bound on the log marginal likelihood $\log p_\theta(x)$, also called the variational lower bound. The variational lower bound can be written in two forms:
1) The joint-contrastive form
2) The prior-contrastive form
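Reconstructed in the notation above, the two forms are:

$$
\mathcal{L}(\theta,\phi;x) = \mathbb{E}_{q_\phi(z|x)}\big[\log p_\theta(x,z) - \log q_\phi(z|x)\big] \quad \text{(joint-contrastive)},
$$

$$
\mathcal{L}(\theta,\phi;x) = -\mathrm{KL}\big(q_\phi(z|x)\,\|\,p_\theta(z)\big) + \mathbb{E}_{q_\phi(z|x)}\big[\log p_\theta(x|z)\big] \quad \text{(prior-contrastive)}.
$$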
The prior-contrastive form can be related to autoencoders: the KL term acts as a regularisation term, while the expectation term is the reconstruction error ($z$ is sampled from $x$ by $g(\epsilon,x)$; $x$ is then sampled from $z$ by $p_\theta(x|z)$).
1.3. Stochastic Gradient Variational Bayes (SGVB) estimator
In order to optimise $\mathcal{L}(\theta,\phi;x)$, we need to differentiate it w.r.t. $\theta$ and $\phi$. However, differentiating w.r.t. $\phi$ is problematic, as the usual Monte Carlo gradient estimator exhibits high variance^{3}. The usual Monte Carlo method estimates the terms by sampling $z$ from $q_\phi(z|x)$:
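The high-variance estimator referred to here is the score-function estimator (reconstructed from Kingma & Welling^{1}):

$$
\nabla_\phi\, \mathbb{E}_{q_\phi(z|x)}\big[f(z)\big] = \mathbb{E}_{q_\phi(z|x)}\big[f(z)\, \nabla_\phi \log q_\phi(z|x)\big] \approx \frac{1}{L}\sum_{l=1}^{L} f\big(z^{(l)}\big)\, \nabla_\phi \log q_\phi\big(z^{(l)}|x\big), \qquad z^{(l)} \sim q_\phi(z|x).
$$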
Instead, we use a reparametrisation trick so that the Monte Carlo estimate becomes differentiable. We reparametrise $z$ as:
where the transformation $g_\phi(\cdot)$ and auxiliary noise $\epsilon$ are chosen based on the form of $q_\phi(z|x)$.
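That is:

$$
\tilde{z} = g_\phi(\epsilon, x), \qquad \epsilon \sim p(\epsilon);
$$

for example, for a diagonal-Gaussian $q_\phi(z|x)$ with mean $\mu$ and standard deviation $\sigma$ output by the encoder, $g_\phi(\epsilon,x) = \mu + \sigma \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0,\mathrm{I})$.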
Figure 2 How reparametrisation trick changes network structure
Essentially, we are assuming that sampling $z$ from $q_\phi(z|x)$ is equivalent to sampling $\epsilon$ from $p(\epsilon)$ (Figure 2). As a result, we can rewrite integrals involving $q_\phi(z|x)$ in forms involving $\epsilon$ instead:
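As a concrete sketch (plain NumPy, hypothetical shapes and names), sampling $z$ via the reparametrisation trick for a diagonal-Gaussian $q_\phi(z|x)$:

```python
import numpy as np

rng = np.random.default_rng(0)

def reparameterise(mu, log_var, rng):
    """Sample z ~ N(mu, diag(exp(log_var))) via z = mu + sigma * eps.

    The randomness lives entirely in eps ~ N(0, I), so z is a
    deterministic, differentiable function of (mu, log_var).
    """
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps

# Toy encoder outputs for a 4-dimensional latent space.
mu = np.zeros(4)
log_var = np.zeros(4)          # sigma = 1 everywhere
z = reparameterise(mu, log_var, rng)
print(z.shape)                 # (4,)
```

In an autodiff framework the same pattern lets gradients w.r.t. `mu` and `log_var` flow through the sample, which is exactly what the score-function estimator cannot do with low variance.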
We can then compute the SGVB estimator for both forms of variational lower bound:
1) The joint-contrastive form
2) The prior-contrastive form
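Reconstructed, the two estimators are:

$$
\tilde{\mathcal{L}}^{A}(\theta,\phi;x) = \frac{1}{L}\sum_{l=1}^{L}\Big[\log p_\theta\big(x, z^{(l)}\big) - \log q_\phi\big(z^{(l)}|x\big)\Big],
$$

$$
\tilde{\mathcal{L}}^{B}(\theta,\phi;x) = -\mathrm{KL}\big(q_\phi(z|x)\,\|\,p_\theta(z)\big) + \frac{1}{L}\sum_{l=1}^{L}\log p_\theta\big(x|z^{(l)}\big),
$$

where $z^{(l)} = g_\phi(\epsilon^{(l)}, x)$ and $\epsilon^{(l)} \sim p(\epsilon)$.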
The prior-contrastive form typically has lower variance, but relies on the KL term having a closed form. In particular, if both the true prior and the approximate posterior are Gaussian, the KL term can be easily computed.
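A minimal numerical sketch (NumPy; the toy likelihood and all names are hypothetical) of the prior-contrastive estimator with the Gaussian closed-form KL:

```python
import numpy as np

rng = np.random.default_rng(0)

def kl_to_std_normal(mu, log_var):
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ), closed form."""
    return 0.5 * np.sum(np.exp(log_var) + mu**2 - 1.0 - log_var)

def sgvb_estimate(mu, log_var, log_px_given_z, L=10):
    """Prior-contrastive SGVB estimate:
    -KL(q||p) + (1/L) * sum_l log p(x | z_l), with z_l = mu + sigma*eps_l."""
    sigma = np.exp(0.5 * log_var)
    recon = 0.0
    for _ in range(L):
        z = mu + sigma * rng.standard_normal(mu.shape)
        recon += log_px_given_z(z)
    return -kl_to_std_normal(mu, log_var) + recon / L

# Toy likelihood: x is "decoded" as z itself with unit-variance Gaussian noise.
x = np.array([0.5, -0.5])
log_px = lambda z: -0.5 * np.sum((x - z) ** 2)

bound = sgvb_estimate(np.zeros(2), np.zeros(2), log_px, L=100)
print(bound)
```

In a real VAE, `log_px_given_z` would be the decoder network's log-likelihood rather than this toy function.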
2. Application to biomedical segmentation
2.1. Problem scenario
CNN-based biomedical segmentation usually requires a large set of annotated data and does not take anatomical knowledge into account. In the Dalca et al. paper^{2}, a two-step AEVB approach is proposed to carry out unsupervised segmentation learning with an anatomical prior incorporated.
Consider a model consisting of image data $x$, its segmentation map $s$, and the latent variable $z$, which represents the underlying shape embedding (Figure 3).
Figure 3 Generative model for the segmentation problem
The model generates values based on the following prior and likelihood densities:
1) the prior probability of the latent variable $z$ is modeled by a standard normal distribution
2) the labels $s$ are drawn from a categorical distribution given $z$
where $f_{j,m}(\cdot; \theta_{s|z})$ is the probability of label $m$ at voxel $j$
3) the voxel intensity values $x$ are drawn from normal distributions with diagonal covariance matrix
where $\delta(s[j]=m) = 1$ if $s[j]=m$ and $0$ otherwise.
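In symbols, one way to write the three densities above (a reconstruction consistent with the definitions of $f_{j,m}$ and $\delta$; the per-label means $\mu_m$ and variances $\sigma^2_m$ are this note's notation):

$$
p(z) = \mathcal{N}(z;\, 0, \mathrm{I}), \qquad p_\theta(s|z) = \prod_j f_{j,s[j]}(z;\, \theta_{s|z}),
$$

$$
p_\theta(x|s) = \prod_j \mathcal{N}\big(x[j];\, \mu_{s[j]},\, \sigma^2_{s[j]}\big), \qquad \mu_{s[j]} = \sum_m \delta(s[j]=m)\,\mu_m, \quad \sigma^2_{s[j]} = \sum_m \delta(s[j]=m)\,\sigma^2_m.
$$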
We aim to learn segmentation from an image using maximum a posteriori (MAP) estimation:
However, $p_\theta(s,x)$ is intractable as it requires integrating over $z$:
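Written out (reconstructed from the model densities above):

$$
\hat{s} = \arg\max_s\, p_\theta(s|x) = \arg\max_s\, p_\theta(s,x), \qquad p_\theta(s,x) = \int p(z)\, p_\theta(s|z)\, p_\theta(x|s) \mathop{dz}.
$$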
Similarly, $p_\theta(z|x,s)$ is also intractable, which means the EM algorithm does not work in this case.
Instead, we will optimise $\log p_\theta(x)$ and $\log p_\theta(s)$ by optimising their respective variational lower bounds. Note that we cannot optimise the two probability distributions together because we do not have paired sets of $x$ and $s$ in the unsupervised learning setting. Nevertheless, we want to learn mappings between the latent space and both the image space and the label space, so that we can apply the anatomical knowledge learnt from the label-latent mapping to new images.
2.2. SGVB estimator derivations
2.2.1. Learning anatomical prior
Using the AEVB framework, we approximate the true posterior $p_\theta(z|s)$ with $q_\phi(z|s)$, which is modeled as a normal distribution with diagonal covariance matrix:
The respective variational lower bound is:
Using the reparametrisation trick, we can estimate the expectation term as:
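In the AEVB notation of section 1, these expressions can be reconstructed as:

$$
q_\phi(z|s) = \mathcal{N}\big(z;\, \mu_{z|s},\, \mathrm{diag}\,\sigma^2_{z|s}\big),
$$

$$
\mathcal{L}(\theta,\phi;s) = -\mathrm{KL}\big(q_\phi(z|s)\,\|\,p(z)\big) + \mathbb{E}_{q_\phi(z|s)}\big[\log p_\theta(s|z)\big],
$$

$$
\mathbb{E}_{q_\phi(z|s)}\big[\log p_\theta(s|z)\big] \approx \frac{1}{L}\sum_{l=1}^{L} \log p_\theta\big(s|z^{(l)}\big), \qquad z^{(l)} = \mu_{z|s} + \sigma_{z|s} \odot \epsilon^{(l)}, \quad \epsilon^{(l)} \sim \mathcal{N}(0,\mathrm{I}).
$$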
Since both the true prior $p(z)$ and the approximate posterior $q_\phi(z|s)$ are Gaussian, the KL term has a closed form (see Appendix A for the proof):
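With $p(z) = \mathcal{N}(0,\mathrm{I})$ and $q_\phi(z|s) = \mathcal{N}(\mu_{z|s}, \mathrm{diag}\,\sigma^2_{z|s})$, the standard closed form is:

$$
\mathrm{KL}\big(q_\phi(z|s)\,\|\,p(z)\big) = \frac{1}{2}\sum_{k}\Big(\sigma^2_{z|s,k} + \mu^2_{z|s,k} - 1 - \log \sigma^2_{z|s,k}\Big).
$$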
2.2.2. Unsupervised learning
Similar to the previous section, we approximate the true posterior $p_\theta(z|x)$ with $q_\phi(z|x)$, which is modeled as a normal distribution with diagonal covariance matrix:
The respective variational lower bound is:
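By analogy with the prior-contrastive form of section 1.2:

$$
q_\phi(z|x) = \mathcal{N}\big(z;\, \mu_{z|x},\, \mathrm{diag}\,\sigma^2_{z|x}\big), \qquad \mathcal{L}(\theta,\phi;x) = -\mathrm{KL}\big(q_\phi(z|x)\,\|\,p(z)\big) + \mathbb{E}_{q_\phi(z|x)}\big[\log p_\theta(x|z)\big].
$$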
Again, the KL term has a closed-form solution:
The expectation term, on the other hand, is intractable as we cannot compute $p_\theta(x|z)$. Instead, we will use a lower bound of the term, derived based on Jensen’s inequality (more details in Appendix B):
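A standard reconstruction of this bound, marginalising over $s$:

$$
\log p_\theta(x|z) = \log \sum_s p_\theta(x|s)\, p_\theta(s|z) \;\ge\; \sum_s p_\theta(s|z)\, \log p_\theta(x|s) = \mathbb{E}_{p_\theta(s|z)}\big[\log p_\theta(x|s)\big].
$$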
We can then estimate the expectation term, sampling $z$ using the reparametrisation trick:
2.3. Network architecture
Two CNNs are used, one for learning the anatomical prior and one for unsupervised learning of segmentation from image.
Figure 4 shows the autoencoder architecture for learning anatomical prior. This network learns the mapping between the latent space and the label space, using only segmentation maps.
Figure 4 Autoencoder to learn anatomical prior
Given an input segmentation map $s$, the encoder outputs the parameters of the posterior $q_\phi(z|s)$, $\mu_{z|s}$ and $\sigma_{z|s}$. The latent variable $z$ is then sampled using the reparametrisation trick, where $z = g(\epsilon,s) = \mu_{z|s} + \sigma_{z|s} \epsilon$. The transformation $g(\cdot)$ mimics sampling from the normal distribution by adding scaled noise to the mean. Finally, $s$ is reconstructed by the decoder with $p_\theta(s|z)$, represented by the top output layer in Figure 4.
An additional prior is added in the actual implementation, represented by the bottom output layer in Figure 4. This location prior $p_{loc}(s)$ is computed as the voxelwise frequency of labels in the prior training dataset. By multiplying the prior $p_{loc}(s)$ with the reconstructed segmentation map from $p_\theta(s|z)$ (adding the logarithms in the actual implementation), the output segmentation is obtained.
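A small sketch (NumPy; shapes and function names are hypothetical, not the paper's code) of combining the decoder's per-voxel label probabilities with the location prior by adding logarithms:

```python
import numpy as np

def combine_with_location_prior(decoder_probs, loc_prior, eps=1e-7):
    """Multiply p_theta(s|z) by the location prior p_loc(s) voxelwise,
    implemented as adding log-probabilities, then renormalise over labels.

    decoder_probs, loc_prior: arrays of shape (num_voxels, num_labels).
    """
    log_post = np.log(decoder_probs + eps) + np.log(loc_prior + eps)
    log_post -= log_post.max(axis=1, keepdims=True)   # numerical stability
    post = np.exp(log_post)
    return post / post.sum(axis=1, keepdims=True)

# Two voxels, three labels.
dec = np.array([[0.6, 0.3, 0.1],
                [0.2, 0.5, 0.3]])
loc = np.array([[0.1, 0.8, 0.1],
                [0.9, 0.05, 0.05]])
seg = combine_with_location_prior(dec, loc).argmax(axis=1)
print(seg)   # the prior pulls voxel 0 toward label 1 and voxel 1 toward label 0
```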
The parameters of $q_\phi(z|s)$ and $p_\theta(s|z)$, as well as the output segmentation labels $s[j]$, are then used to estimate and optimise the SGVB lower bound.
Figure 5 shows the architecture for unsupervised learning of segmentation from input images. The decoder part of this network is copied from the prior-learning network of Figure 4, which helps to incorporate the anatomical prior knowledge learnt.
Figure 5 Architecture for unsupervised learning
The image encoder is first trained in a similar fashion to the prior autoencoder; the pretrained weights are used for initialisation during the actual training.
Given an input scan $x$, the image encoder outputs the parameters of the posterior $q_\phi(z|x)$, $\mu_{z|x}$ and $\sigma_{z|x}$. The latent variable $z$ is then sampled using the reparametrisation trick, where $z = g(\epsilon,x) = \mu_{z|x} + \sigma_{z|x} \epsilon$. Then, the already-trained prior decoder outputs $s$, which is used to generate the reconstructed $x$ with $p_\theta(x|s)$.
The parameters of $q_\phi(z|x)$ and $p_\theta(x|s)$, as well as the reconstructed image $x$, are then used to estimate and optimise the SGVB lower bound.
2.4. Results
The networks are tested on two datasets (Figure 6):
1) T1w scan dataset: more than 14,000 T1-weighted MRI scans from ADNI, ABIDE, GSP, etc., all resampled to 256x256x256 (1mm isotropic) and cropped to 160x192x224 to remove entirely-background voxels.
2) T2-FLAIR scan dataset: 3800 T2-FLAIR scans from ADNI (5mm slice spacing), linearly registered to T1w images (using ANTs). No ground truth segmentations are available.
Figure 6 Example T1w and T2-FLAIR images
A prior training set is composed using 5000 T1w images with ground truth segmentation maps generated by FreeSurfer (with manual correction and QC). This set is used to train the prior autoencoder. The rest of the T1w images are split into training, validation and test sets for unsupervised learning. In the T2 scan case, subjects common to both datasets are excluded from the prior training set.
Figure 7 shows segmentation results from 3 subjects in the T1w scan dataset, in comparison to their respective ground truth. Figure 8 shows one subject’s segmentation results from the T2-FLAIR scan dataset. The estimated segmentation boundaries are generally aligned with the ground truth boundaries, but are much smoother and lack fine details.
Figure 7 Example segmentation results for T1w scan dataset
Figure 8 Example segmentation results for T2-FLAIR scan dataset
Appendix
A. Closedform solution for KL divergence
When two distributions are both Gaussian, the KL divergence between them has a closed form. I will start with the proof for the generic case, and then show the solution for a special case.
Consider two multivariate normal distributions of dimension $k$, $p(x) \sim \mathcal{N} (\mu_1, \Sigma_1)$ and $q(x) \sim \mathcal{N} (\mu_2, \Sigma_2)$. There is no constraint on the form of the covariance matrices. Since the KL divergence is defined as $KL(p(x)\|q(x)) = \mathbb{E}_p [\log p(x) - \log q(x)]$ for continuous distributions, we start by expressing the log likelihoods of the two distributions.
For $p(x)$:
where $|\cdot|$ is the matrix determinant.
Similarly for $q(x)$:
Then we have:
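The result of the derivation is the standard expression:

$$
\mathrm{KL}\big(p(x)\,\|\,q(x)\big) = \frac{1}{2}\left[\log\frac{|\Sigma_2|}{|\Sigma_1|} - k + \mathrm{Tr}\big(\Sigma_2^{-1}\Sigma_1\big) + (\mu_2-\mu_1)^T \Sigma_2^{-1} (\mu_2-\mu_1)\right].
$$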
Some tricks used are:
1) trace trick: $\mathbb{E}[x] = \mathbb{E}[\mathrm{Tr}(x)]$ if $x$ is a scalar
2) $\mathbb{E}_p [ (x-\mu_1)^T \Sigma_2^{-1} (\mu_1-\mu_2)]$: this is equal to 0, since the expression is a linear function of $(x-\mu_1)$ and $\mathbb{E}_p[x-\mu_1] = 0$.
Now, consider a special case, where $p(x) \sim \mathcal{N} (\mu_1, \mathrm{diag}\,\sigma_1^2)$ and $q(x) \sim \mathcal{N} (0, \mathrm{I})$. In this case, we can simplify the solution even further:
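Substituting $\mu_2 = 0$, $\Sigma_2 = \mathrm{I}$ and $\Sigma_1 = \mathrm{diag}\,\sigma_1^2$ into the general result gives:

$$
\mathrm{KL}\big(p(x)\,\|\,q(x)\big) = \frac{1}{2}\sum_{k}\Big(\sigma^2_{1,k} + \mu^2_{1,k} - 1 - \log \sigma^2_{1,k}\Big).
$$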
This is the solution used in section 2.2.
B. Jensen’s inequality
Jensen’s inequality states that the secant line of a convex function lies above the graph of the function. The concept is intuitive.
Mathematically, for any convex function $f$, Jensen’s inequality is:
Conversely, for any concave function $g$, the sign is simply flipped:
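In expectation form:

$$
f(\mathbb{E}[x]) \le \mathbb{E}[f(x)] \quad \text{(convex } f\text{)}, \qquad g(\mathbb{E}[x]) \ge \mathbb{E}[g(x)] \quad \text{(concave } g\text{)}.
$$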
When probability density functions are involved (e.g. when computing expectation), Jensen’s inequality is:
where $f$ is the convex function, $g$ is any real-valued measurable function, and $p(x)$ is the probability density function (i.e. $\int p(x) \mathop{dx} = 1$).
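With this notation, the inequality reads:

$$
f\!\left(\int g(x)\, p(x) \mathop{dx}\right) \;\le\; \int f\big(g(x)\big)\, p(x) \mathop{dx}.
$$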
In section 2.2.2, $\log$ is a concave function, and $p_\theta(s|z)$ is the probability density function. Therefore:
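That is, applying the concave form of the inequality:

$$
\log p_\theta(x|z) = \log \mathbb{E}_{p_\theta(s|z)}\big[p_\theta(x|s)\big] \;\ge\; \mathbb{E}_{p_\theta(s|z)}\big[\log p_\theta(x|s)\big].
$$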

Diederik P Kingma and Max Welling. “Auto-Encoding Variational Bayes”. In: arXiv e-prints (2013). arXiv: 1312.6114

Adrian V Dalca, John Guttag, and Mert R Sabuncu. “Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation”. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018, pp. 9290–9299.

David M Blei, Michael I Jordan, and John W Paisley. “Variational Bayesian Inference with Stochastic Search”. In: Proceedings of the 29th International Conference on Machine Learning (ICML-12). 2012, pp. 1367–1374.