Combination of VAE and GAN
In this article, I’ll explain the Adversarial Autoencoder (AAE), a hybrid between VAE and GAN for generative modelling. Before reading this, I recommend reading my previous article about the Variational Autoencoder (VAE), as I will assume readers are already familiar with VAE.
Generative Adversarial Network, or GAN, is one of the approaches to deep generative modelling that is becoming really popular right now. The main differences between GAN and VAE are that GAN tries to match the distribution at the pixel level rather than an explicit data distribution, and that it uses a different mechanism to optimize the model distribution toward the true distribution.
How does GAN generate an image? An image can be thought of as just a vector of pixel values. However, not just any random values will produce an image of an object. An image of a dog needs its pixels to take certain values, arranged in a certain way, to look like a dog. Hence, we can say that such vectors need to follow a certain distribution. The purpose of GAN, therefore, is to take a random vector as input and transform it so that it follows the pixel distribution of our desired output.
Therefore, the latent variables or the inputs of GAN itself have no meaning except the meaning given by the network. The model will try to map any point in the latent space to a meaningful output.
How to Train GAN?
Although we know that GAN transforms its input to follow the target distribution, how does GAN actually optimize the network to learn the output distribution? There are “direct” and “indirect” approaches. The “direct” approach compares the true distribution with the generated distribution, calculates the error, and optimizes the network accordingly. This is the approach used by Generative Moment Matching Networks (GMMN). However, the true distribution of the output is likely to be complex, unlike a Gaussian distribution, which can be described simply by its mean and variance, so it would be difficult to express either the true or the generated distribution explicitly. Instead, the distributions are compared through samples: based on samples of the true and generated data, we can approximate each distribution and measure the difference.
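As a minimal sketch of such a sample-based comparison, here is the Maximum Mean Discrepancy (MMD), the statistic that GMMN-style models minimize, estimated from two sets of samples. The toy 2-D Gaussians below are purely illustrative:

```python
import numpy as np

def rbf_kernel(a, b, sigma=1.0):
    # Pairwise RBF kernel values between the rows of a and b.
    d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-d2 / (2 * sigma ** 2))

def mmd(real, fake, sigma=1.0):
    # (Biased) sample estimate of the Maximum Mean Discrepancy between
    # the distributions that produced `real` and `fake`.
    return (rbf_kernel(real, real, sigma).mean()
            + rbf_kernel(fake, fake, sigma).mean()
            - 2 * rbf_kernel(real, fake, sigma).mean())

rng = np.random.default_rng(0)
real = rng.normal(0.0, 1.0, size=(200, 2))   # samples from the "true" distribution
close = rng.normal(0.0, 1.0, size=(200, 2))  # generated samples that match it
far = rng.normal(3.0, 1.0, size=(200, 2))    # generated samples that do not

print(mmd(real, close))  # near zero
print(mmd(real, far))    # clearly larger
```

A “direct” generator would be trained by backpropagating through this statistic; GAN avoids estimating it explicitly.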
On the other hand, GAN follows the “indirect” approach, where another network, the discriminator, is tasked with classifying real and generated samples. In short, a GAN has two components: the discriminator, which takes both true and generated samples and is trained to classify them, and the generator, which takes a random vector, generates a fake sample, and is trained to fool the discriminator into classifying that fake sample as real.
The steps in training a GAN are:
1. Train the discriminator
- Freeze the generator’s weights (only the discriminator is updated)
- Generate fake samples using the generator (initially these will be noise, as the generator is untrained)
- Sample real data
- Train the discriminator on the real and fake samples (using the real and fake labels)
2. Train the generator
- Freeze the discriminator’s weights (only the generator is updated)
- Generate fake samples using the generator
- Train the generator using the discriminator’s output on the fake samples, with the target label set to “real”
3. Repeat
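The alternating steps above can be sketched end to end on a toy problem. The snippet below is a hand-rolled 1-D GAN (a linear generator, a logistic discriminator, and manually derived gradients); the target distribution N(4, 1) and all hyperparameters are illustrative choices, not anything from the article:

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

# Generator g(z) = a*z + b maps noise z ~ N(0, 1); the "real" data is N(4, 1).
# Discriminator D(x) = sigmoid(w*x + c) classifies real vs fake scalars.
a, b = 1.0, 0.0   # generator parameters
w, c = 0.0, 0.0   # discriminator parameters
lr, batch = 0.05, 64

for step in range(3000):
    # --- 1. Train the discriminator (generator frozen) ---
    z = rng.normal(size=batch)
    fake = a * z + b                        # fakes from the frozen generator
    real = rng.normal(4.0, 1.0, size=batch)
    d_real = sigmoid(w * real + c)          # should approach 1
    d_fake = sigmoid(w * fake + c)          # should approach 0
    # For binary cross-entropy, the gradient w.r.t. the logit is (D - label).
    w -= lr * (((d_real - 1.0) * real).mean() + (d_fake * fake).mean())
    c -= lr * ((d_real - 1.0).mean() + d_fake.mean())

    # --- 2. Train the generator (discriminator frozen) ---
    z = rng.normal(size=batch)
    fake = a * z + b
    d_fake = sigmoid(w * fake + c)
    grad_logit = d_fake - 1.0               # non-saturating loss -log D(fake)
    a -= lr * (grad_logit * w * z).mean()   # chain rule: dlogit/da = w*z
    b -= lr * (grad_logit * w).mean()       # chain rule: dlogit/db = w

print("generated mean:", (a * rng.normal(size=10000) + b).mean())
```

As training alternates, the generator’s offset b drifts from 0 toward the real data’s mean, which is exactly the “fool the discriminator” pressure described above.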
I recommend reading Joseph Rocca’s article to understand GANs in more depth. I love his illustrations.
Adversarial Autoencoder (AAE)
Adversarial Autoencoder (AAE) is a clever blend of the autoencoder architecture with the adversarial loss concept introduced by GAN. It uses a similar concept to the Variational Autoencoder (VAE), except that it regularizes the latent code with an adversarial loss instead of the KL-divergence that VAE uses.
In VAE, KL-divergence (a measure of the difference between two distributions) is used to match the encoded latent code to a normal distribution (or any arbitrary distribution of your choice). AAE replaces this with an adversarial loss: an additional discriminator component is added, and the encoder acts as the generator. Unlike GAN, where the generator outputs a generated image and the discriminator receives both real and fake images, AAE’s generator outputs a latent code and tries to fool the discriminator into believing that the latent code was sampled from the chosen distribution. The discriminator, in turn, predicts whether a given latent code was produced by the encoder (fake) or sampled from the chosen distribution (real).
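For reference, the VAE regularizer that AAE replaces has a simple closed form when the posterior is Gaussian and the prior is a standard normal. A quick sketch (mu and log_var would come from a trained encoder; here they are placeholders):

```python
import numpy as np

def kl_to_standard_normal(mu, log_var):
    # Closed-form KL( N(mu, sigma^2) || N(0, I) ), summed over latent dimensions:
    # 0.5 * sum( mu^2 + sigma^2 - 1 - log(sigma^2) )
    return 0.5 * np.sum(mu ** 2 + np.exp(log_var) - 1.0 - log_var)

print(kl_to_standard_normal(np.zeros(2), np.zeros(2)))           # 0.0: already N(0, I)
print(kl_to_standard_normal(np.array([3.0, 0.0]), np.zeros(2)))  # 4.5: penalized
```

AAE keeps the same goal (push the latent codes toward a chosen prior) but drops this analytic formula in favor of a learned discriminator, which is what lets it target priors with no closed-form density.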
There are three types of encoder you can use:
- Deterministic, this is the same encoder used in a plain autoencoder: it compresses the input into a set of features represented as a vector z.
- Gaussian posterior, this is the same encoder used in VAE: instead of encoding each feature as a single value, the encoder outputs a Gaussian distribution for each feature, described by two variables, mean and variance.
- Universal approximator posterior, this also encodes the features as a distribution, except that the distribution is not assumed to be Gaussian. In this case, the encoder is a function f(x, n), where x is the input and n is random noise drawn from any arbitrary distribution.
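The three options can be sketched as follows (W, mu, log_var, and f stand in for the outputs of a trained encoder network; they are placeholders here):

```python
import numpy as np

rng = np.random.default_rng(0)

def deterministic_encode(W, x):
    # Deterministic encoder: z is a fixed function of x (here a toy linear map).
    return W @ x

def gaussian_posterior_encode(mu, log_var):
    # Gaussian posterior: sample z ~ N(mu, sigma^2) via the reparameterization
    # trick z = mu + sigma * eps, so gradients can still flow to mu and log_var.
    eps = rng.normal(size=np.shape(mu))
    return mu + np.exp(0.5 * np.asarray(log_var)) * eps

def universal_posterior_encode(f, x, noise_dim):
    # Universal approximator posterior: z = f(x, n) with arbitrary input noise n;
    # no Gaussian form is assumed for the resulting latent distribution.
    n = rng.normal(size=noise_dim)
    return f(x, n)
```

The deterministic encoder gives one z per input, while the two stochastic variants give a whole distribution of codes per input, which generally makes the aggregated latent distribution smoother and easier to match to the prior.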
Therefore, AAE architecture consists of these components:
- Encoder, the encoder will take the input and transform it into a lower dimension (latent code z)
- Decoder, the decoder will take the latent code z and transform it into the generated image.
- Discriminator, the discriminator takes as input both a random vector z sampled from the chosen distribution (real) and the encoded latent code z (fake) produced by the encoder, and predicts whether each input is real or fake.
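A minimal sketch of how data flows through these three components, using toy linear layers and illustrative dimensions (a real AAE would use deep networks here):

```python
import numpy as np

rng = np.random.default_rng(0)

x_dim, z_dim = 784, 2                                 # e.g. a flattened 28x28 image
W_enc = rng.normal(scale=0.01, size=(z_dim, x_dim))   # encoder (toy linear layer)
W_dec = rng.normal(scale=0.01, size=(x_dim, z_dim))   # decoder (toy linear layer)
w_disc = rng.normal(scale=0.01, size=z_dim)           # discriminator over latent codes

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

x = rng.normal(size=x_dim)          # an input image, flattened

z_fake = W_enc @ x                  # encoder output = the "fake" latent code
x_recon = W_dec @ z_fake            # decoder reconstruction of the input
z_real = rng.normal(size=z_dim)     # a "real" code sampled from the prior N(0, I)

d_real = sigmoid(w_disc @ z_real)   # discriminator's verdict on a prior sample
d_fake = sigmoid(w_disc @ z_fake)   # discriminator's verdict on the encoded code
```

Training alternates between a reconstruction phase (update encoder and decoder to make x_recon match x) and a regularization phase (update the discriminator to separate z_real from z_fake, then update the encoder to fool it).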
From the architecture above, you can see the two main differences between AAE and GAN: the encoder and the discriminator. AAE takes an image as input instead of a random vector z like GAN does; this is achieved by adding an encoder at the start. Additionally, AAE tries to make the latent code follow the normal distribution (or your chosen distribution). This is done by changing the discriminator’s task to predicting whether a latent code z was sampled from that distribution or generated by the encoder, unlike GAN, where the discriminator predicts whether a given image is real or generated (fake).
AAE is very similar to VAE except for the regularizing term on the latent code. But, which one is actually better in matching the data distribution? In the AAE paper, a comparison study was done on MNIST.
In the results reported in the paper, AAE produces a cleaner latent space than VAE when the latent code is trained to match a 2D Gaussian distribution or a mixture of ten 2D Gaussians on the MNIST dataset. Additionally, given label information, AAE is better able to map the latent code onto the 2D Gaussian distribution, and even onto more complex distributions such as the Swiss roll.
The latent space also manages to represent style consistently. If we walk along the Swiss roll axis, the middle part of each mixture component represents an upright writing style, while the right and left parts represent tilted writing styles.
One of the exciting applications of AAE is in anomaly detection and localization, which I am currently researching. The main challenge in anomaly detection is the lack of anomalous data, so an unsupervised method is needed. In this setting, an autoencoder is trained on normal data so that it reconstructs an anomalous image into a normal one. An anomaly can then be detected by computing the difference between the anomaly-free reconstruction and the original anomalous image.
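A sketch of this reconstruction-based scoring (the 8x8 images and the perfect “reconstruction” below are toy stand-ins for a trained model’s output):

```python
import numpy as np

def anomaly_map(original, reconstructed):
    # Pixel-wise squared difference; high values localize the anomaly.
    return (original - reconstructed) ** 2

def anomaly_score(original, reconstructed):
    # A single scalar for detection; threshold it to flag anomalous images.
    return anomaly_map(original, reconstructed).mean()

# Toy example: the autoencoder (assumed trained on normal data only)
# reconstructs everything as the normal image.
normal = np.zeros((8, 8))
anomalous = normal.copy()
anomalous[2:4, 2:4] = 1.0          # a small defect
reconstructed = normal             # the model "removes" what it never learned

print(anomaly_score(normal, reconstructed))     # 0.0
print(anomaly_score(anomalous, reconstructed))  # > 0, flags the defect
```

The anomaly map gives localization for free: the defect shows up exactly where the original and the reconstruction disagree.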
With AAE, the performance of the autoencoder can be improved by the adversarial loss. A few examples of anomaly detection methods that use elements of AAE are GANomaly, Skip-GANomaly, AnoGAN, and CAVGA, which are among the current state of the art (SOTA).
[1] Makhzani, A., Shlens, J., Jaitly, N., Goodfellow, I., & Frey, B. (2015). Adversarial autoencoders. arXiv preprint arXiv:1511.05644.
[2] https://towardsdatascience.com/paper-summary-adversarial-autoencoders-f89bfa221e48
[3] https://towardsdatascience.com/understanding-generative-adversarial-networks-gans-cd6e4651a29