Matryoshka Representation Learning: A Comprehensive Guide to Multi-Scale Embeddings

In the ever-evolving landscape of machine learning, the ability to strike a balance between model performance and computational cost is paramount. Matryoshka Representation Learning (MRL) emerges as a compelling solution, offering a unique approach to representation learning that encodes information at multiple granularities within a single embedding. This article delves into the intricacies of MRL, exploring its underlying principles, training methodologies, applications, and benefits.

Introduction: The Need for Flexible Representations

Learned representations extract useful features from data and make downstream tasks significantly easier. However, traditional representation learning methods typically produce fixed-size embeddings, which can be limiting. High-dimensional representations are more expressive, but they also impose significant computational overhead, especially for tasks that do not need such high fidelity. This motivates a flexible representation that can serve a range of dimensions, accepting some loss of downstream accuracy at the smaller sizes in exchange for lower cost.

MRL addresses this challenge by enabling the creation of multi-scale representations within a single model. Inspired by Matryoshka dolls, where smaller dolls fit inside larger ones, MRL allows a model to output representations of varying sizes, from coarse to fine, using a single forward pass. This approach is particularly effective for tasks like semantic search, information retrieval, multilingual processing, and any application requiring nuanced representations of data across different levels of abstraction.

Understanding Embeddings

Embeddings are numerical representations of complex objects, such as text, images, or audio. They serve as a versatile tool in natural language processing and other fields, enabling practitioners to solve a wide variety of tasks. An embedding model typically produces embeddings of a fixed size. However, as research has progressed, state-of-the-art embedding models have started producing embeddings with increasingly higher output dimensions, representing each input with more values.

Matryoshka Embeddings: Nested Information

Matryoshka embedding models address the challenge of high-dimensional embeddings by storing more critical information in the earlier dimensions and less important information in the later dimensions. These models are trained to ensure that smaller, truncated embeddings remain useful. The name "Matryoshka" is inspired by the Russian nesting dolls, reflecting the nested structure of the embeddings.


Why Use Matryoshka Embedding Models?

Matryoshka embedding models offer several advantages, including:

  • Cost-Performance Trade-off: By choosing how many embedding dimensions to keep, you can trade a small amount of model performance for lower cost.
  • Efficient Shortlisting and Reranking: MRL facilitates efficient shortlisting and reranking in information retrieval tasks. Smaller embeddings can be used to quickly shortlist a set of candidate documents, while larger embeddings can be used to rerank the shortlisted candidates for improved precision.
  • Reduced Memory Footprint: Smaller embeddings require less memory for storage, which can be crucial in resource-constrained environments.
  • Faster Retrieval: Smaller embeddings speed up retrieval because the cost of each similarity computation grows with the number of dimensions.
  • Adaptability: MRL allows a single embedding to adapt to the computational constraints of downstream tasks.

How are Matryoshka Embedding Models Trained?

The MRL approach can be adopted for almost all embedding model training frameworks. In a standard training step for an embedding model, embeddings are produced for a training batch, and a loss function is used to create a loss value that represents the quality of the produced embeddings. For Matryoshka embedding models, the training step also involves producing embeddings for the training batch. However, the loss function is used to determine not just the quality of the full-size embeddings but also the quality of the embeddings at various dimensionalities. The loss values for each dimensionality are then aggregated, resulting in a final loss value. This incentivizes the model to frontload the most important information at the start of an embedding, such that it will be retained if the embedding is truncated.
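The aggregation step can be illustrated with a minimal sketch. It assumes a pairwise base loss (here a toy cosine-distance loss chosen only for illustration) and uses plain Python lists instead of framework tensors, so no gradients are computed; in real training the same truncate-and-sum logic would run on tensors. The names `matryoshka_loss`, `cosine_distance_loss`, and the `weights` argument are hypothetical, not any library's API.

```python
import math
from typing import Callable, Optional, Sequence

Vector = Sequence[float]


def normalize(v: Vector) -> list:
    """Scale v to unit L2 norm (assumes v is not the zero vector)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def cosine_distance_loss(anchor: Vector, positive: Vector) -> float:
    """Toy base loss: 1 - cosine similarity of a positive pair."""
    a, p = normalize(anchor), normalize(positive)
    return 1.0 - sum(x * y for x, y in zip(a, p))


def matryoshka_loss(
    base_loss: Callable[[Vector, Vector], float],
    anchor: Vector,
    positive: Vector,
    dims: Sequence[int],
    weights: Optional[Sequence[float]] = None,
) -> float:
    """Evaluate base_loss on the first m dimensions for every m in dims and
    aggregate the results as a weighted sum (equal weights by default)."""
    if weights is None:
        weights = [1.0] * len(dims)
    return sum(w * base_loss(anchor[:m], positive[:m]) for m, w in zip(dims, weights))


# The 2-dim prefixes agree exactly, but the full 4-dim vectors do not,
# so only the m = 4 term contributes (roughly 0.667).
anchor = [1.0, 0.0, 0.5, -0.5]
positive = [1.0, 0.0, -0.5, 0.5]
print(matryoshka_loss(cosine_distance_loss, anchor, positive, dims=[2, 4]))
```

Because the smaller prefixes are scored on their own, any information the model places only in later dimensions is invisible to those terms, which is what pushes important information toward the front. Sentence Transformers, for example, provides a `MatryoshkaLoss` wrapper that applies a base loss at several dimensionalities in a similar spirit.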

Mathematical Formulation

Given a labeled dataset $\mathcal{D} = \{(x_1, y_1), \ldots, (x_N, y_N)\}$, where $x_i \in \mathcal{X}$ is an input point and $y_i \in [L]$ is the label of $x_i$, MRL optimizes the multi-class classification loss for each nested dimension $m \in \mathcal{M}$ using standard empirical risk minimization. It does so with a separate linear classifier for each dimension, parameterized by $W^{(m)} \in \mathbb{R}^{L \times m}$. The losses from these classifiers are then aggregated, weighted by their relative importance $(c_m \ge 0)_{m \in \mathcal{M}}$. That is, MRL optimizes:

$$\min_{\{W^{(m)}\}_{m \in \mathcal{M}},\ \theta_F} \ \frac{1}{N} \sum_{i \in [N]} \sum_{m \in \mathcal{M}} c_m \cdot \mathcal{L}\left(W^{(m)} \cdot F(x_i; \theta_F)_{1:m} ;\ y_i\right)$$

where $F(\cdot\,; \theta_F)$ is the encoder with parameters $\theta_F$, $F(x_i; \theta_F)_{1:m}$ denotes the first $m$ coordinates of its output, and $\mathcal{L}$ is the multi-class softmax cross-entropy loss.
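The objective above can be evaluated directly for a small batch. This is a minimal sketch in plain Python that computes the loss value only (no optimization). It assumes labels are 0-based (0 to L-1 rather than the 1-based $[L]$), that each classifier $W^{(m)}$ is stored as a list of L rows of length m, and the helper names `mrl_objective`, `matvec`, and `softmax_cross_entropy` are hypothetical.

```python
import math
from typing import Dict, List, Sequence

Matrix = List[List[float]]  # row-major, shape (L, m)


def softmax_cross_entropy(logits: Sequence[float], label: int) -> float:
    """Multi-class softmax cross-entropy, using the log-sum-exp trick for stability."""
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    return log_sum_exp - logits[label]


def matvec(W: Matrix, z: Sequence[float]) -> List[float]:
    """Multiply an (L x m) matrix by an m-dimensional vector."""
    return [sum(w * x for w, x in zip(row, z)) for row in W]


def mrl_objective(
    features: Sequence[Sequence[float]],  # F(x_i; theta_F) for each example, length d
    labels: Sequence[int],                # y_i, 0-based class index
    classifiers: Dict[int, Matrix],       # m -> W^(m) with shape (L, m)
    importance: Dict[int, float],         # m -> c_m >= 0
) -> float:
    """(1/N) * sum_i sum_m c_m * CE(W^(m) F(x_i)_{1:m}, y_i)."""
    total = 0.0
    for z, y in zip(features, labels):
        for m, W in classifiers.items():
            logits = matvec(W, z[:m])  # classifier m only sees the first m coordinates
            total += importance[m] * softmax_cross_entropy(logits, y)
    return total / len(features)


# Toy usage: d = 4, L = 3 classes, nested dimensions M = {2, 4}, all c_m = 1.
features = [[0.2, -0.1, 0.7, 0.3], [1.0, 0.5, -0.2, 0.0]]
labels = [0, 2]
classifiers = {
    2: [[0.5, 0.1], [-0.3, 0.2], [0.1, -0.4]],
    4: [[0.5, 0.1, 0.3, 0.0], [-0.3, 0.2, 0.0, 0.1], [0.1, -0.4, 0.2, 0.6]],
}
print(mrl_objective(features, labels, classifiers, importance={2: 1.0, 4: 1.0}))
```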

Weight Tying

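A natural way to make this more efficient is to tie weights across all the linear classifiers, i.e., to define $W^{(m)} = W_{1:m}$, the first $m$ columns of a common weight matrix $W$. Because every classifier is then a slice of one matrix, only the largest classifier has to be stored. When the nested dimensions double at each step, the untied classifiers together hold roughly twice as many weights as the largest one, so weight tying reduces the memory cost of the linear classifiers by almost half. This would be crucial for extremely large output spaces.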
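The "almost half" figure can be checked with a quick parameter count. This sketch assumes nested dimensions that double from 8 up to 2048 and a hypothetical 1,000-class output space, and ignores bias terms; `classifier_params` and `tied_head` are illustrative helpers, not part of any library.

```python
def classifier_params(num_classes: int, dims: list, tied: bool) -> int:
    """Parameter count of the linear heads (biases ignored).

    Untied: one L x m matrix per nested dimension m.
    Tied:   a single L x max(dims) matrix W; W^(m) is its first m columns.
    """
    if tied:
        return num_classes * max(dims)
    return num_classes * sum(dims)


def tied_head(W: list, m: int) -> list:
    """W^(m) = W_{1:m}: keep the first m columns of every row of W."""
    return [row[:m] for row in W]


dims = [8 * 2**k for k in range(9)]  # assumed granularities: 8, 16, ..., 2048
untied = classifier_params(1000, dims, tied=False)
tied = classifier_params(1000, dims, tied=True)
print(untied, tied, round(tied / untied, 3))  # 4088000 2048000 0.501
```

Since 8 + 16 + ... + 2048 = 4088, which is just under 2 × 2048, the tied head needs about 50.1% of the untied parameters.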


How to Use Matryoshka Embedding Models

In practice, obtaining embeddings from a Matryoshka embedding model works the same way as with a normal embedding model. The only difference is that, after receiving the embeddings, you can optionally truncate them to a smaller dimensionality. The truncated embeddings can then be used directly or stored for later use.
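Truncation itself is just slicing. This minimal sketch uses hypothetical 8-dimensional vectors in place of real model outputs; `truncate` and `cosine_similarity` are illustrative helpers. Cosine similarity divides by the norms of the truncated vectors, so scores stay on a comparable scale at every size.

```python
import math
from typing import List, Sequence


def truncate(embedding: Sequence[float], dim: int) -> List[float]:
    """Keep only the first `dim` values of a Matryoshka embedding."""
    if not 0 < dim <= len(embedding):
        raise ValueError(f"dim must be in 1..{len(embedding)}, got {dim}")
    return list(embedding[:dim])


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity, normalized by the norms of the (truncated) vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


# Hypothetical 8-dimensional embeddings standing in for real model outputs.
query = [0.9, 0.3, -0.2, 0.1, 0.05, -0.4, 0.2, 0.1]
doc = [0.8, 0.4, -0.1, 0.3, -0.2, 0.1, -0.3, 0.2]

for dim in (2, 4, 8):
    score = cosine_similarity(truncate(query, dim), truncate(doc, dim))
    print(dim, round(score, 3))
```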

Implementation Example

Here's an example of how to use a Matryoshka embedding model with the SentenceTransformers library:

from sentence_transformers import SentenceTransformer

matryoshka_dim_short = 64
matryoshka_dim_long = 768  # full embedding size of this mpnet-base model
text = ["The weather is so nice!"]

# truncate_dim makes encode() return only the first N dimensions of each embedding.
short_model = SentenceTransformer("tomaarsen/mpnet-base-nli-matryoshka", truncate_dim=matryoshka_dim_short)
long_model = SentenceTransformer("tomaarsen/mpnet-base-nli-matryoshka", truncate_dim=matryoshka_dim_long)
short_embedding = short_model.encode(text)
long_embedding = long_model.encode(text)

print(f"Shape: {short_embedding.shape}, {long_embedding.shape}")  # (1, 64), (1, 768)
print(short_embedding[0][0:10])
print(long_embedding[0][0:10])

Applications of MRL

MRL has a wide range of applications, including:

  • Adaptive Classification: MRL can be used for adaptive classification, where the model starts with a coarser representation and moves to a finer representation only if it is not confident of its predictions.
  • Adaptive Retrieval: MRL can be used for adaptive retrieval, where a low-dimensional representation is used to shortlist a large pool of candidates and the top ones are then reranked with a higher-dimensional representation.
  • Shortlisting and Reranking: MRL is well-suited for shortlisting and reranking tasks, where smaller embeddings are used for initial retrieval and larger embeddings are used for refining the results.
  • Funnel Retrieval: Funnel retrieval is a technique that halves the shortlist size and doubles the representation size at every step of re-ranking.
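The adaptive classification idea from the list above can be sketched as a cascade over granularities. This assumes one trained classifier head per nested dimension (stood in for here by hypothetical toy linear heads) and uses the top softmax probability as the confidence measure, which is one possible choice, not necessarily the one used in the MRL paper. `adaptive_classify` and `linear_head` are illustrative names.

```python
import math
from typing import Callable, Dict, List, Sequence, Tuple

# A head maps an m-dimensional prefix of the representation to class logits.
Head = Callable[[Sequence[float]], List[float]]


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def adaptive_classify(
    representation: Sequence[float], heads: Dict[int, Head], threshold: float
) -> Tuple[int, int]:
    """Try granularities from coarsest to finest and stop at the first one whose
    top softmax probability reaches `threshold`; the largest is always accepted.
    Returns (predicted_class, dimensions_used)."""
    dims = sorted(heads)
    for m in dims:
        probs = softmax(heads[m](representation[:m]))
        confidence = max(probs)
        if confidence >= threshold or m == dims[-1]:
            return probs.index(confidence), m
    raise ValueError("heads must not be empty")


def linear_head(W: List[List[float]]) -> Head:
    """Hypothetical toy head: logits = W z."""
    return lambda z: [sum(w * x for w, x in zip(row, z)) for row in W]


heads = {
    2: linear_head([[1.0, 0.0], [0.0, 1.0]]),
    4: linear_head([[1.0, 0.0, 2.0, 0.0], [0.0, 1.0, 0.0, 2.0]]),
}
print(adaptive_classify([5.0, 0.0, 0.0, 0.0], heads, threshold=0.9))  # (0, 2): confident early
print(adaptive_classify([0.1, 0.1, 3.0, 0.0], heads, threshold=0.9))  # (0, 4): needs all dims
```

Easy inputs exit at the cheap 2-dimensional head, while ambiguous ones fall through to the full representation, so the average cost depends on how many inputs are confidently classified early.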

Shortlisting and Reranking in Detail

As the name suggests, this method has two stages:

  • Shortlisting: Small embeddings are used to retrieve a handful of top candidate documents quickly, for example 200 out of 1,000 documents. Here, some accuracy is traded for speed and efficiency.
  • Reranking: Full-size embeddings are used to reorder those 200 candidates; for example, D1, D2, D3 might become D2, D3, D1. Reranking adds little cost, because it runs on only a small fraction of the original documents.

So, in short:


First, an MRL model produces embeddings that can be truncated to several sizes. Then, small truncated embeddings are used to retrieve the top documents for a query from a large collection. Finally, the retrieved documents are reranked with the full-size embeddings, which yields the most relevant documents for the query.
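The two stages can be sketched as follows. This is a minimal exact-search version over a hypothetical in-memory corpus, scoring with plain dot products (normalize each slice first if you want cosine similarity); `shortlist_then_rerank` and its parameters are illustrative names, and a production system would typically use an ANN index for the first stage.

```python
import heapq
from typing import Dict, List, Sequence


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def shortlist_then_rerank(
    query: Sequence[float],
    corpus: Dict[str, Sequence[float]],
    shortlist_dim: int,
    shortlist_size: int,
    top_k: int,
) -> List[str]:
    """1. Shortlist: score every document on the first `shortlist_dim` dimensions
          and keep the best `shortlist_size` candidates.
       2. Rerank: rescore only those candidates with the full embeddings
          and return the ids of the best `top_k`."""
    q_short = query[:shortlist_dim]
    candidates = heapq.nlargest(
        shortlist_size, corpus, key=lambda doc_id: dot(q_short, corpus[doc_id][:shortlist_dim])
    )
    return heapq.nlargest(top_k, candidates, key=lambda doc_id: dot(query, corpus[doc_id]))


# Hypothetical 4-dimensional corpus.
corpus = {
    "d1": [1.0, 0.0, 0.0, 0.0],
    "d2": [0.9, 0.1, 0.0, 1.0],
    "d3": [0.0, 1.0, 0.0, 0.0],
}
query = [1.0, 0.0, 0.0, 1.0]
print(shortlist_then_rerank(query, corpus, shortlist_dim=2, shortlist_size=2, top_k=2))  # ['d2', 'd1']
```

The 2-dimensional shortlist drops d3 and ranks d1 above d2; the full-size rerank then swaps them, mirroring the D1, D2 to D2, D1 reordering described above.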

MRL Experimental Results

The MRL approach has been evaluated on models across several modalities: vision (ResNet50 and ViT), text (BERT), and vision-text (ALIGN). These models were assessed primarily on two common use cases: classification and retrieval.

Classification

Trained on ImageNet-1K, the MRL ResNet50 achieves top-1 accuracy comparable to independently trained standard ResNet50 models at each feature representation size. The 1-nearest-neighbor (1-NN) accuracy was also measured for each representation size; in this setup, the MRL ResNet50 is up to 2% more accurate than its fixed-feature counterpart at each size.

The ViT model trained with MRL on the JFT-300M dataset is also very competitive across all representation sizes: its 1-NN accuracy is comparable to that of a ViT trained with fixed-size feature representations. At lower dimensions, the MRL model even outperforms the fixed-size baseline, partly because that baseline's low-dimensional representations are formed by randomly selecting features from the fixed-size model. A similar trend holds for ALIGN: the model trained with MRL matches the performance of ALIGN trained with fixed-size representations.

Retrieval

Retrieval quality was compared between a ResNet50 trained with MRL and one trained with fixed-size feature representations, using mean Average Precision (mAP). The MRL model improves mAP by up to 3% over its fixed-size counterpart at every representation size.

In adaptive retrieval, the MRL model can theoretically speed up retrieval by up to 128×. In real-world applications, where approximate nearest neighbor (ANN) algorithms are commonly used, the MRL setup achieves a 14× speedup compared to retrieval using the HNSW algorithm on identical hardware.
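A back-of-the-envelope cost model shows how a theoretical speedup of this size can arise. This sketch counts only the multiply-adds of exact dot-product search, and every number in it (corpus size, 16-dimensional shortlist, 200 candidates, 2048-dimensional full embeddings) is a hypothetical choice for illustration, not a setting taken from the reported experiments.

```python
def flat_search_cost(num_docs: int, dim: int) -> int:
    """Multiply-adds to score one query against every document exactly."""
    return num_docs * dim


def adaptive_search_cost(num_docs: int, shortlist_dim: int, shortlist_size: int, full_dim: int) -> int:
    """Shortlist all documents at low dimension, then rerank the shortlist at full dimension."""
    return num_docs * shortlist_dim + shortlist_size * full_dim


N, D = 1_000_000, 2048  # hypothetical corpus size and full embedding size
full = flat_search_cost(N, D)
adaptive = adaptive_search_cost(N, shortlist_dim=16, shortlist_size=200, full_dim=D)
print(full, adaptive, round(full / adaptive, 1))  # 2048000000 16409600 124.8
```

As the corpus grows relative to the shortlist, the rerank term becomes negligible and the ratio approaches 2048 / 16 = 128. Wall-clock gains are smaller because this model ignores memory access, index structure, and other overheads.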

Why Does MRL Work?

MRL learns a coarse-to-fine hierarchy of nested subspaces and packs information efficiently into all of them, even though it is explicitly trained only at logarithmically spaced granularities (for example, sizes that double at each step). The nested loss enforces a vector subspace structure: the first m dimensions of each learned representation lie in a subspace that is contained in the space spanned by every larger prefix.

Normalization and Quantization

In addition to slicing the embeddings, normalization and quantization can further improve the performance and efficiency of MRL.

  • Normalization: L2-normalize each slice (for example, 64, 128, or 256 dimensions) so that it sits in a stable geometry and dot products equal cosine similarities. In practice, this can tighten precision and recover relevant items that were previously missed, while pushing clearly unrelated pairs further apart.
  • Quantization: After normalization, apply lightweight quantization (such as int8) to reduce storage costs, often without a noticeable quality hit.
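A minimal sketch of the slice, normalize, and quantize pipeline follows. It assumes simple symmetric int8 scalar quantization (scale by 127 and round), which is one common scheme among several; `prepare`, `l2_normalize`, `quantize_int8`, and `dequantize_int8` are hypothetical helper names.

```python
import math
from typing import List, Sequence


def l2_normalize(v: Sequence[float]) -> List[float]:
    norm = math.sqrt(sum(x * x for x in v))
    if norm == 0:
        raise ValueError("cannot normalize a zero vector")
    return [x / norm for x in v]


def quantize_int8(v: Sequence[float]) -> List[int]:
    """Symmetric scalar quantization of a unit vector: every value lies in
    [-1, 1], so scale by 127 and round (the clamp is a defensive guard)."""
    return [max(-127, min(127, round(x * 127))) for x in v]


def dequantize_int8(q: Sequence[int]) -> List[float]:
    return [x / 127 for x in q]


def prepare(embedding: Sequence[float], dim: int) -> List[int]:
    """Truncate to `dim`, re-normalize the slice, then quantize it."""
    return quantize_int8(l2_normalize(embedding[:dim]))


print(prepare([3.0, 4.0, 100.0], dim=2))  # [76, 102]
```

Re-normalizing after truncation matters because a prefix of a unit vector is generally shorter than unit length. Storing int8 instead of float32 uses 1 byte per value instead of 4; combined with truncating 768 dimensions to 64, that is 64 bytes instead of 3,072 per vector.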

Current Limitations

The nested losses used to train MRL at different dimensions are weighted equally (all $c_m$ equal), which may not give the best balance between accuracy and efficiency.
