Neural Discrete Representation Learning: A Comprehensive Tutorial

Abstract

Learning useful representations without supervision remains a key challenge in machine learning. This article explores a generative model designed to learn discrete representations: the Vector Quantised-Variational AutoEncoder (VQ-VAE). The VQ-VAE differs from traditional VAEs in two key aspects: the encoder network outputs discrete codes instead of continuous ones, and the prior is learned rather than static. By incorporating vector quantization (VQ), the model addresses the issue of "posterior collapse" often observed in VAEs, where latent variables are ignored when paired with a powerful autoregressive decoder.

Introduction to VQ-VAE

The Variational Autoencoder (VAE) is a generative model that compresses an input into a continuous latent vector within a latent space. VQ-VAE instead learns a discretized latent space, which is intuitively better suited to data with discrete structure, such as images. The discretization itself comes from vector quantization; on top of it, a powerful PixelCNN prior is fit over the latent space, which is more expressive than the typical Gaussian assumption. Instead of normally distributed variables, the latent space consists of single-channel images of reduced dimension.

The Main Idea Behind VQ-VAE

The core concept of VQ-VAE is a discrete latent space: each encoder output is replaced by its nearest entry in a learned codebook, so an input is represented by a single-channel image of codebook indices with reduced dimensions rather than by normally distributed variables. A robust PixelCNN prior is then fit over these indices, offering greater expressiveness than the standard Gaussian assumption.

From Probabilistic to Deterministic

To understand why the VQ-VAE can generate more decisive outputs than a VAE, we can examine how the posterior is constructed in each model.

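In VQ-VAE the posterior is deterministic: it is a one-hot categorical distribution that puts all of its mass on the codebook vector e_k nearest to the encoder output z_e(x):

$$q(z = k \mid x) = \begin{cases} 1 & \text{if } k = \arg\min_j \lVert z_e(x) - e_j \rVert_2 \\ 0 & \text{otherwise} \end{cases}$$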

In a VAE, the posterior is defined as follows:

$$q_\theta(z \mid x) = \mathcal{N}\big(z;\, \mu_\theta(x),\, \sigma_\theta^2(x) I\big)$$

By definition, we want the probability of z given the input x, which is modeled as a Gaussian distribution. The decoder then uses a latent vector z to reconstruct an image x_hat. It is worth noting that infinitely many values of z can be sampled from this Gaussian distribution (see the posterior expression again).

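To maximize the probability of this reconstruction, we work with the log-likelihood:

$$\log p_\phi(x \mid z)$$

Since there are many possible values of z, the reconstruction term has to be expressed as an expectation over the posterior:

$$\mathbb{E}_{z \sim q_\theta(z \mid x)}\big[\log p_\phi(x \mid z)\big]$$

We want to minimize the negative log-likelihood during training, so the reconstruction part of the training objective can be defined as

$$\mathcal{L}_{\text{recon}} = -\mathbb{E}_{z \sim q_\theta(z \mid x)}\big[\log p_\phi(x \mid z)\big]$$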

In essence, this expectation tells us how likely it is that the image is reconstructed as x_hat, averaged over all latent vectors z drawn from the posterior. This is where the problem arises: in a VAE the latent code is sampled from a Gaussian posterior, so for the same input the decoder must average over all of these possible latent codes.

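On the other hand, the VQ-VAE defines the reconstruction loss as follows:

$$\mathcal{L}_{\text{recon}} = -\log p_\phi\big(x \mid z_q(x)\big)$$

Here z_q(x) is the codebook vector selected for x. There is only one possible latent vector z, so the reconstruction loss does not need an expectation as it does in the VAE.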
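The difference between the two reconstruction losses can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the identity decoder, the squared-error likelihood, and all numeric values are made up for the example and do not come from any particular implementation.

```python
import random

def decoder(z):
    """Toy decoder (an assumption for illustration): the reconstruction is z itself."""
    return list(z)

def neg_log_likelihood(x, x_hat):
    """0.5 * squared error: a unit-variance Gaussian negative log-likelihood, up to a constant."""
    return 0.5 * sum((a - b) ** 2 for a, b in zip(x, x_hat))

def vae_recon_loss(x, mu, sigma, rng, n_samples=1000):
    """Monte Carlo estimate of the expectation over z ~ N(mu, sigma^2)."""
    total = 0.0
    for _ in range(n_samples):
        z = [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]
        total += neg_log_likelihood(x, decoder(z))
    return total / n_samples

def vq_recon_loss(x, z_q):
    """Single evaluation at the one quantized latent: no expectation is needed."""
    return neg_log_likelihood(x, decoder(z_q))

rng = random.Random(0)
x = [0.2, -0.1]                                   # hypothetical input
vae_loss = vae_recon_loss(x, mu=x, sigma=[0.5, 0.5], rng=rng)
vq_loss = vq_recon_loss(x, z_q=[0.0, 0.0])        # hypothetical nearest codebook vector
```

The VAE loss has to be estimated by averaging over many sampled latents, whereas the VQ-VAE loss is a single deterministic evaluation that returns the same value every time it is called.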

Addressing Limitations of Continuous Latent Representations

A variational autoencoder (VAE) is commonly known as a generative model that compresses an input into a continuous latent vector in a latent space. Despite its popularity, the VAE suffers from an issue rooted in the concept of a continuous latent representation. In some real-world applications a dataset contains contrasting classes, yet the latent vector is continuous. Whenever the model has to generate something that lies between two data points, it interpolates between them to produce a new output. Take as an example interpolating between an image of a cat and an image of a person: the result is an unreasonable output.

To avoid this, we want the latent space representation to be discrete instead of continuous [1]. We achieve this by mapping the encoder output into an embedding space called the codebook. The codebook consists of a finite set of learned vectors, and each encoder output, a continuous vector that can take infinitely many values, is mapped to one of them.

Vector Quantization (VQ) in VQ-VAE

Vector quantization (VQ) is a data compression technique, similar to the k-means clustering algorithm, that models the probability density function of data using a set of representative vectors known as the codebook. It has found applications in various machine learning areas, especially in vector quantized variational autoencoders (VQ-VAE). Vector quantization is the operation of mapping every data point in a Voronoi cell to that cell's codebook point, which is the nearest one. Vector quantization cannot be used directly in gradient-based machine learning optimization, since no useful gradient is defined for this nearest-neighbor assignment.
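The nearest-codebook mapping described above can be sketched in plain Python as follows. The function name `quantize`, the 3-entry codebook, and the data points are hypothetical and chosen only to keep the example small.

```python
def quantize(vectors, codebook):
    """Map each vector to its nearest codebook entry (squared Euclidean distance).

    Returns (indices, quantized), where quantized[i] is a copy of codebook[indices[i]].
    """
    def sq_dist(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))

    indices = [
        min(range(len(codebook)), key=lambda k: sq_dist(v, codebook[k]))
        for v in vectors
    ]
    quantized = [list(codebook[k]) for k in indices]
    return indices, quantized

codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 0.5]]   # hypothetical codebook, K = 3
vectors = [[0.1, -0.2], [0.9, 1.2], [-0.8, 0.4]]   # hypothetical data points
indices, quantized = quantize(vectors, codebook)
```

Every point inside the same Voronoi cell receives the same index, which is why small changes to the input leave the output unchanged and the gradient of the mapping is zero almost everywhere.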

VQ-VAE Architecture

The VQ-VAE architecture typically comprises three main components:

  1. Encoder: This network, denoted as q_theta(z|x), maps the input data (x) to a continuous latent representation.
  2. Codebook: This contains a set of learned embedding vectors. The encoder output is then mapped to the nearest embedding vector in the codebook.
  3. Decoder: This network, denoted as p_phi(x|z), reconstructs the input from the quantized latent representation.

The Role of the Codebook

The codebook is what makes the latent space discrete instead of continuous [1]. It consists of a finite set of learned vectors, and each encoder output, a continuous vector that can take infinitely many values, is replaced by the codebook vector nearest to it.

The latent space in our example will be of size one, that is, we have a single embedding vector representing the latent code for each input sample. Each encoder output, a 16-dimensional vector, needs to be mapped to the embedding vector it is closest to. The quantizer therefore serves two purposes: it acts as a store for the embedding vectors, whose current state is kept in the codebook, and it performs this nearest-vector mapping.

Loss Functions in VQ-VAE

The loss function in VQ-VAE is designed to optimize both the reconstruction quality and the codebook usage. The overall loss typically consists of three terms:

  1. Reconstruction Loss: Measures the difference between the input and the reconstructed output.
  2. Codebook Loss: Encourages the codebook vectors to be close to the encoder outputs.
  3. Commitment Loss: Encourages the encoder to commit to using the codebook vectors.
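Written out, the three terms are usually combined as below. The notation is an assumption carried over from the VQ-VAE paper rather than from the text above: z_e(x) is the encoder output, e is the selected codebook vector, z_q(x) is the quantized latent passed to the decoder, sg[·] is the stop-gradient operator, and β is the commitment weight (0.25 in the paper).

```latex
\mathcal{L}
  = \underbrace{-\log p_\phi\big(x \mid z_q(x)\big)}_{\text{reconstruction}}
  + \underbrace{\big\lVert \mathrm{sg}[z_e(x)] - e \big\rVert_2^2}_{\text{codebook}}
  + \underbrace{\beta \, \big\lVert z_e(x) - \mathrm{sg}[e] \big\rVert_2^2}_{\text{commitment}}
```

The two squared-distance terms have the same value but different gradients: the stop-gradient makes the codebook term move e toward the encoder output, while the commitment term moves the encoder output toward e.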

KL Divergence Loss: Ignored

In order to compute the KL divergence loss we need two things: the prior and the posterior.

  • Prior. In a standard VAE the prior is a Gaussian distribution. In VQ-VAE the prior used during training is uniform over the codebook, $p(z = k) = 1/K$, where K is the total number of embedding vectors.
  • Posterior. Since the process of quantization is deterministic, the posterior distribution is categorical, with probability 1 for the nearest codeword and 0 for the others.

If we substitute this prior and posterior into the KL divergence formula, the KL divergence equals log K. Since this is just a constant, we can ignore it.
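The substitution can be written out in one line, assuming a uniform prior p(z = k) = 1/K over the K codebook entries, a one-hot posterior that selects entry k*, and the usual convention 0 log 0 = 0:

```latex
D_{\mathrm{KL}}\big(q(z \mid x) \,\|\, p(z)\big)
  = \sum_{k=1}^{K} q(z = k \mid x) \, \log \frac{q(z = k \mid x)}{p(z = k)}
  = 1 \cdot \log \frac{1}{1/K}
  = \log K
```

Only the term for k = k* survives the sum, and the result does not depend on the encoder, the decoder, or the codebook.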

Codebook and Commitment Loss

During backpropagation, the gradient flows back without difficulty from the output through the decoder.

But when it reaches the codebook, the gradient becomes zero. Why? The problem is that deterministic quantization is a non-differentiable function. To overcome this we can use an approach called Straight-Through Estimation (STE). STE assumes that the gradient flowing into the decoder input equals the gradient flowing into the encoder output, so the gradient is copied across the quantization step unchanged.
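The straight-through rule can be illustrated with hand-written gradients in plain Python. This is a sketch under stated assumptions: the element-wise linear decoder, the squared-error loss, and the numbers are invented for the example, and a real model would rely on an automatic differentiation framework instead.

```python
def nearest(z_e, codebook):
    """Return the codebook vector closest to z_e (squared Euclidean distance)."""
    return min(codebook, key=lambda e: sum((a - b) ** 2 for a, b in zip(z_e, e)))

def decode(z, w):
    """Toy element-wise linear decoder (illustrative assumption): x_hat[i] = w[i] * z[i]."""
    return [wi * zi for wi, zi in zip(w, z)]

def ste_backward(x, z_e, codebook, w):
    """Forward pass through the quantizer, backward pass with the straight-through rule."""
    z_q = nearest(z_e, codebook)                 # forward: non-differentiable lookup
    x_hat = decode(z_q, w)
    # Gradient of 0.5 * ||x_hat - x||^2 with respect to the decoder input z_q.
    grad_z_q = [wi * (xh - xi) for wi, xh, xi in zip(w, x_hat, x)]
    # Straight-through: treat dz_q/dz_e as the identity and copy the gradient.
    grad_z_e = list(grad_z_q)
    return z_q, grad_z_q, grad_z_e

codebook = [[0.0, 0.0], [1.0, 1.0]]              # hypothetical 2-entry codebook
z_q, grad_z_q, grad_z_e = ste_backward(
    x=[0.5, 2.0], z_e=[0.9, 0.8], codebook=codebook, w=[2.0, 1.0]
)
```

In automatic differentiation frameworks the same effect is commonly obtained by writing the decoder input as z_e + sg[z_q - z_e], where sg is the stop-gradient operator. Note that in this scheme the codebook itself receives no gradient from the reconstruction loss.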

Because the straight-through gradient bypasses the codebook, the embedding vectors receive no gradient from the reconstruction loss, and this is why the codebook loss and commitment loss are introduced. These losses are also required because the codebook is randomly initialized, so the encoder outputs may end up associated with only a few codebook entries, as depicted in the figure below.

Straight-Through Estimator (STE)

VQ-VAE was first introduced in the paper "Neural Discrete Representation Learning" by van den Oord et al. [6]. They apply VQ to model a discrete representation of the latent space of a variational autoencoder. In other words, they apply VQ to the output of the encoder, find the best embeddings (codebook vectors) for it, and then pass these embeddings to the decoder, with the straight-through estimator carrying gradients across the quantization step. However, the straight-through estimator (STE) does not consider the influence of quantization, which leads to a mismatch between the gradient and the true behavior of the vector quantization. In addition, methods that use STE must add an extra loss term to the global loss function so that the VQ codebook (embeddings) is updated. The weighting coefficient for this extra loss term is therefore a new hyper-parameter that has to be tuned manually.

NSVQ: An Alternative to STE

In this post, we want to introduce our recently proposed vector quantization technique for machine learning-based approaches, published under the title "NSVQ: Noise Substitution in Vector Quantization for Machine Learning" [8]. NSVQ is a technique in which the vector quantization error is simulated by adding noise to the input vector, such that the simulated noise takes on the shape of the original VQ error distribution. The results in our paper [8] (which are also provided in the following figures) show that NSVQ can not only pass gradients in the backward pass, but also provide more accurate gradients for the codebook than STE. Furthermore, NSVQ leads to higher quantization accuracy and faster convergence compared to STE. In addition, NSVQ performs more deterministically (it shows less variance in performance) when an individual experiment is run several times. One great benefit of NSVQ is that it does not need any additional loss term in the global loss function.
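A minimal sketch of the noise-substitution idea in plain Python follows. It assumes that the simulated quantized vector is formed as x + ||x - x_hat|| * v / ||v||, where x_hat is the nearest codebook vector and v is drawn from a standard normal distribution. This form, the function names, and the numbers are assumptions made for the illustration; [8] should be consulted for the exact formulation.

```python
import math
import random

def nearest(x, codebook):
    """Return the codebook vector closest to x (squared Euclidean distance)."""
    return min(codebook, key=lambda e: sum((a - b) ** 2 for a, b in zip(x, e)))

def norm(v):
    """Euclidean norm of a vector given as a list of floats."""
    return math.sqrt(sum(c * c for c in v))

def nsvq(x, codebook, rng):
    """Simulate quantization of x by adding noise scaled to the true VQ error norm."""
    x_hat = nearest(x, codebook)                     # hard quantization, used only for the error norm
    err_norm = norm([a - b for a, b in zip(x, x_hat)])
    v = [rng.gauss(0.0, 1.0) for _ in x]             # random direction
    v_norm = norm(v)
    return [a + err_norm * c / v_norm for a, c in zip(x, v)]

rng = random.Random(0)
codebook = [[0.0, 0.0], [1.0, 1.0]]                  # hypothetical 2-entry codebook
x = [0.9, 0.8]                                       # hypothetical encoder output
x_tilde = nsvq(x, codebook, rng)

true_err = norm([a - b for a, b in zip(x, nearest(x, codebook))])
sim_err = norm([a - b for a, b in zip(x_tilde, x)])
```

In this sketch the simulated error has exactly the same norm as the true quantization error but a random direction. Since the output is an ordinary arithmetic function of x and of the selected codebook vector, gradients can be taken through it without a straight-through approximation.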

Sampling from the Latent Space

Sampling from VQ-VAEs is more subtle than normal sampling schemes. To elucidate this point we’ll present three sampling schemes:

  1. Uniform sampling
  2. Categorical sampling from a histogram
  3. Sampling with an autoregressive PixelCNN

Sampling scheme (1) produces scrambled results, because only a small percentage of all possible representations are actually utilized. If you sample uniformly, you will most likely get results outside of the data distribution, which is why they appear random.

Sampling scheme (2) collects a histogram of representations and samples from the histogram. This works but is also flawed, because it limits us to only the representations seen during construction of the histogram. Note that if $N_{\text{histogram}} \rightarrow \infty$ then this scheme will work, but since we can’t actually do this, we want a better way to approximate the distribution of our representation space.

Scheme (3) is more complex but provides a principled way of sampling from the latent space.
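The first two schemes are simple enough to sketch in plain Python. The function names, the codebook size of 8, and the three observed encodings are hypothetical. Scheme (3) is not shown, because it requires training an autoregressive model over the latent indices.

```python
import random
from collections import Counter

def sample_uniform(num_codes, length, rng):
    """Scheme (1): draw every latent index uniformly from the codebook."""
    return [rng.randrange(num_codes) for _ in range(length)]

def sample_from_histogram(observed, rng):
    """Scheme (2): draw a whole representation in proportion to how often it was observed."""
    counts = Counter(tuple(rep) for rep in observed)
    reps = list(counts)
    weights = [counts[r] for r in reps]
    return list(rng.choices(reps, weights=weights, k=1)[0])

rng = random.Random(0)
observed = [[0, 3, 3], [0, 3, 3], [1, 2, 0]]     # hypothetical encodings of a training set
uniform_sample = sample_uniform(num_codes=8, length=3, rng=rng)
hist_sample = sample_from_histogram(observed, rng)
```

The uniform sample can be any of the 8^3 = 512 possible index combinations, most of which never occur in the data, while the histogram sample is always one of the representations that was already observed.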

Applications of VQ-VAE

VQ-VAE has found success in various applications, including:

  • Image generation: VQ-VAE can generate high-fidelity images with diverse content. For instance, Latent Diffusion Models (the basis of Stable Diffusion) first learn a lower-dimensional representation that is perceptually equivalent to the data space, and one of the autoencoder variants they describe regularizes this latent space with vector quantization in the spirit of VQ-VAE. At a high level, the perceptual compression stage removes high-frequency details but still learns the underlying semantic content. Then, the generative model (diffusion, which is essentially a sequence of noising and denoising steps in the latent space) learns the semantic and conceptual composition of the data.
  • Audio generation: OpenAI’s Jukebox model, which generates original songs from raw audio, trains a VQ-VAE on audio and then uses a transformer trained on the latents to generate new audio samples.
  • Text-to-image synthesis: OpenAI’s DALL-E involves a transformer that takes as input both the embeddings from text and the latent codes from a discrete VAE in the style of VQ-VAE trained on images.
  • Speech coding: VQ-VAE has been explored for low bit-rate speech coding.

Advantages of VQ-VAE

  • Discrete latent space: More suitable for discrete data and avoids issues with continuous latent representations.
  • Addresses posterior collapse: The VQ method helps prevent the decoder from ignoring the latent variables.
  • Learned prior: The model learns the prior distribution of the latent space, allowing for more flexibility.

Limitations of VQ-VAE

  • The weighting coefficient for the additional loss term is a new hyper-parameter, which is required to be tuned manually.
  • Straight through estimator (STE) does not consider the influence of quantization and leads to a mismatch between the gradient and true behavior of the vector quantization.
