Estimating the Intrinsic Dimension of Protein Sequence Embeddings using ESM-2

Community Article Published October 18, 2023

TLDR: In this post, we show how to estimate the intrinsic dimension of embeddings obtained from a protein language model (ESM-2), using a technique related to persistent homology. We then discuss how this can be used in curriculum learning for new protein language models.


Understanding the underlying structure and complexity of a dataset is crucial. One way to measure this is to estimate the intrinsic dimension of the data, which has been linked to generalization in neural networks, for example in Intrinsic Dimension, Persistent Homology and Generalization in Neural Networks, Implicit Regularization in Deep Learning May Not Be Explainable by Norms, and The geometry of hidden representations of large transformer models.

Some of this work leads naturally to papers such as Intrinsic Dimensionality Explains the Effectiveness of Language Model Fine-Tuning, which served as inspiration for the now widely adopted LoRA technique for finetuning models in a parameter-efficient way. In particular, we can view LoRA as a regularization technique that can actually improve a model's generalization and provide better performance on unseen data. But this is all a bit tangential to our goal here.

In this post, we'll discuss how to estimate the intrinsic dimension of protein sequences using a method based on the embeddings of a protein language model and persistent homology. In particular, we will look at embeddings from protein language models like ESM-2, and we will estimate the intrinsic dimension of those embeddings using minimum spanning trees and linear regression. This is an estimate of the persistent homology dimension, a kind of fractal dimension. It gives us a measure of the complexity of individual proteins, which may then be used for curriculum learning when training a new pLM.

1. What is Intrinsic Dimension?

The intrinsic dimension of a dataset is a measure of the number of parameters required to describe the data or, in simpler terms, the minimal number of coordinates to represent the data without much loss. It gives us an insight into the "complexity" of the data.

For instance, even if a dataset is in a high-dimensional space, its intrinsic dimension might be low if the data points lie close to a subspace (like a plane or a curve).

Mathematically, if we have points sampled from some manifold $M$ embedded in $\mathbb{R}^D$, the intrinsic dimension is the dimension of $M$.
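To make this concrete, here is a minimal NumPy sketch (the data is synthetic and purely illustrative): points that lie on a 2-dimensional plane embedded in $\mathbb{R}^{50}$ have ambient dimension 50, but only two directions carry any variance, so the intrinsic dimension is 2.

import numpy as np

rng = np.random.default_rng(0)

# 1,000 points on a random 2-D plane embedded in a 50-dimensional ambient space.
latent = rng.normal(size=(1000, 2))   # intrinsic 2-D coordinates
basis = rng.normal(size=(2, 50))      # linear embedding into R^50
points = latent @ basis               # shape (1000, 50)

# Only two singular values are (numerically) nonzero, revealing the intrinsic dimension.
singular_values = np.linalg.svd(points - points.mean(axis=0), compute_uv=False)
print(int(np.sum(singular_values > 1e-8)))  # 2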

2. Token Embeddings

Imports

import numpy as np
from sklearn.linear_model import LinearRegression
from scipy.spatial import distance_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
from transformers import AutoTokenizer, AutoModel
import torch

Embeddings

To estimate the intrinsic dimension of a protein sequence, the first step is to compute embeddings for each token (amino acid) in the sequence. Embeddings are dense vectors that capture the contextual information of tokens.

def get_embeddings(text, model_name="facebook/esm2_t6_8M_UR50D"):
    """
    Compute embeddings for each token in the text using a specified model.
    
    Parameters:
    - text (str): The input text for which embeddings need to be computed.
    - model_name (str): The path to the pretrained model.
    
    Returns:
    - numpy.ndarray: A matrix where each row is the embedding of a token in the text.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=1024)
    with torch.no_grad():
        outputs = model(**inputs)

    # Return embeddings after removing <cls> and <eos> tokens and converting to numpy.
    return outputs.last_hidden_state[:, 1:-1, :].squeeze(0).numpy()

Here, we use a transformer model to obtain the embeddings: the tokenizer splits the protein sequence into tokens (one per amino acid), and the model returns an embedding for each token. We drop the special <cls> and <eos> tokens that ESM-2 adds around the sequence.
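As a quick sanity check on the shape of the output (the peptide below is just a made-up example, and the hidden size of 320 is specific to the esm2_t6_8M_UR50D checkpoint):

# Hypothetical 8-residue peptide, purely for checking shapes.
peptide = "MKTAYIAK"
emb = get_embeddings(peptide)

# One row per amino acid (special tokens stripped), one column per hidden unit.
print(emb.shape)  # expected: (8, 320) for esm2_t6_8M_UR50D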

3. Persistent Scores

Once we have embeddings for each token, we use a concept from topological data analysis to compute a persistent score for a subset of these embeddings. The persistent score is based on the sum of edge weights in the Minimum Spanning Tree (MST) formed from the distance matrix of the embeddings:

def compute_persistent_score(embeddings):
    """
    Compute the persistent score for a subset of embeddings using the sum of edge weights in the MST.
    
    Parameters:
    - embeddings (numpy.ndarray): A matrix where each row is an embedding.
    
    Returns:
    - float: The persistent score for the embeddings.
    """
    dist_matrix = distance_matrix(embeddings, embeddings)
    mst = minimum_spanning_tree(dist_matrix)
    return mst.sum()

A Minimum Spanning Tree is a subset of the edges of a connected, edge-weighted graph that connects all the vertices together, without any cycles and with the minimum possible total edge weight. The sum of these weights serves as a measure of the "spread" or "complexity" of the embeddings in their space.
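As a tiny worked example (the points are made up for illustration), four points at the corners of a unit square are connected by an MST with three edges of length 1, so their persistent score is 3.0:

import numpy as np

# Corners of a unit square; the MST uses three unit-length edges.
square = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
print(compute_persistent_score(square))  # 3.0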

4. Sampling and Scoring

Given the embeddings, the next step involves sampling subsets of tokens and computing their persistent scores. This is done repeatedly for various sample sizes:

def sample_and_score(embeddings, n, k=8, hat_n=40, J=7):
    """
    For various sample sizes, compute the median persistent score across J samples.
    
    Parameters:
    - embeddings (numpy.ndarray): A matrix where each row is an embedding.
    - n (int): Total number of embeddings.
    - k (int): Number of different sample sizes.
    - hat_n (int): A parameter for determining sample sizes.
    - J (int): Number of samples for each sample size.
    
    Returns:
    - list: List of sample sizes.
    - list: List of corresponding median persistent scores.
    """
    scores = []
    sizes = [(i - 1) * (n - hat_n) // k + hat_n for i in range(1, k + 1)]
    
    for size in sizes:
        subset_scores = [compute_persistent_score(embeddings[np.random.choice(n, size, replace=False)])
                         for _ in range(J)]
        scores.append(np.median(subset_scores))
    
    return sizes, scores

For each sample size, we randomly choose a subset of tokens, compute the persistent score, and repeat this $J$ times to get a median score for that size.
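For example, with n = 100 token embeddings and the default settings k = 8 and hat_n = 40, the sample sizes interpolate between hat_n and just under n:

n, k, hat_n = 100, 8, 40
sizes = [(i - 1) * (n - hat_n) // k + hat_n for i in range(1, k + 1)]
print(sizes)  # [40, 47, 55, 62, 70, 77, 85, 92]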

5. Estimating Dimension

The crux of the approach lies in the relationship between the sample sizes and their corresponding persistent scores. For points sampled from a $d$-dimensional set, the total MST edge length grows roughly like $n^{(d-1)/d}$, so on a log-log scale the relationship between sample size and persistent score is approximately linear, with slope close to $1 - 1/d$. The slope of the line of best fit therefore gives us an estimate of the intrinsic dimension:

$$\text{dimension} = \frac{1}{1 - \text{slope}}$$

Here's how it's done:

def estimate_dimension(sizes, scores):
    """
    Estimate the intrinsic dimension of the data using linear regression on log-transformed sizes and scores.
    
    Parameters:
    - sizes (list): List of sample sizes.
    - scores (list): List of corresponding median persistent scores.
    
    Returns:
    - float: Estimated dimension of the data.
    """
    log_sizes = np.log(sizes).reshape(-1, 1)
    log_scores = np.log(scores)

    reg = LinearRegression().fit(log_sizes, log_scores)
    slope = reg.coef_[0]
    
    return 1 / (1 - slope)
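A useful sanity check (a sketch with synthetic data, not part of the original pipeline) is to run the estimator on points drawn uniformly from a cube of known dimension; the recovered value should land near the true dimension, up to sampling noise and finite-sample bias.

import numpy as np

rng = np.random.default_rng(0)

# 1,000 points drawn uniformly from a 5-dimensional cube, zero-padded to a
# 64-dimensional ambient space.
d, n_points = 5, 1000
cloud = np.zeros((n_points, 64))
cloud[:, :d] = rng.uniform(size=(n_points, d))

sizes, scores = sample_and_score(cloud, n_points)
print(estimate_dimension(sizes, scores))  # typically in the ballpark of 5; the estimate is noisy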

6. Bringing It All Together

Finally, to get a robust estimate of the intrinsic dimension, we repeat the sampling, scoring, and dimension estimation several times and average the results:

def estimate_sequence_dimension(text, runs=5):
    """
    Estimate the intrinsic dimension of the text by repeatedly sampling subsets of its tokens, 
    computing their persistent scores, and then using linear regression on the log-transformed values.
    
    Parameters:
    - text (str): The input text for which the dimension needs to be estimated.
    - runs (int): Number of runs with different random seeds.
    
    Returns:
    - float: Estimated dimension of the text.
    """
    embeddings = get_embeddings(text)
    n = embeddings.shape[0]
    
    slopes = []
    for _ in range(runs):
        sizes, scores = sample_and_score(embeddings, n)
        log_sizes = np.log(sizes).reshape(-1, 1)
        log_scores = np.log(scores)
        
        reg = LinearRegression().fit(log_sizes, log_scores)
        slopes.append(reg.coef_[0])
    
    kappa_F = np.mean(slopes)
    return 1 / (1 - kappa_F)

When applied to a protein sequence:

text = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"
dimension = estimate_sequence_dimension(text)
print(f"Estimated dimension of the protein sequence: {dimension}")

which prints:

Estimated dimension of the protein sequence: 13.063370658316673

The output provides an estimated intrinsic dimension for the protein sequence.

Other Methods (Scikit-Dimension)

Various other methods for estimating the intrinsic dimension of data exist, such as "MLE" and "TwoNN", and depending on your use case you may wish to consider these methods as well. Let's have a look at the methods available in scikit-dimension and see whether any come close to our estimate. First, you will need to run

pip install scikit-dimension

Then, be sure to import it:

import skdim

Next, we can use the following function to print the intrinsic dimension as estimated by the various methods available in skdim (note that this estimate_dimension shadows the regression-based helper of the same name defined earlier, so give it a different name if you need both in the same session):

methods = {
    "corr_int": skdim.id.CorrInt,
    "danco": skdim.id.DANCo,
    "ess": skdim.id.ESS,
    "fisher_s": skdim.id.FisherS,
    "knn": skdim.id.KNN,
    "lpca": skdim.id.lPCA,
    "mada": skdim.id.MADA,
    "mind_ml": skdim.id.MiND_ML,
    "mle": skdim.id.MLE,
    "mom": skdim.id.MOM,
    "tle": skdim.id.TLE,
    "twonn": skdim.id.TwoNN
}

def estimate_dimension(embeddings, method="twonn"):
    """
    Estimate the intrinsic dimension of embeddings using the specified method.
    
    Parameters:
    - embeddings (numpy.ndarray): A matrix where each row is an embedding.
    - method (str): The method to use for dimension estimation.
    
    Returns:
    - float: The estimated intrinsic dimension.
    """
    
    if method not in methods:
        raise ValueError(f"Unknown method: {method}")
    
    id_est = methods[method]().fit(embeddings)
    return id_est.dimension_

Then, you can use this as follows:

# Example usage:
text = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"
embeddings = get_embeddings(text)

for method in methods.keys():
    dimension_estimate = estimate_dimension(embeddings, method=method)
    print(f"Estimated Intrinsic Dimension ({method.upper()}): {dimension_estimate}")

This will print:

Estimated Intrinsic Dimension (CORR_INT): 7.599677519235372
Estimated Intrinsic Dimension (DANCO): 310.95695219616096
Estimated Intrinsic Dimension (ESS): 45.757353328926165
Estimated Intrinsic Dimension (FISHER_S): 11.925543780836733
Estimated Intrinsic Dimension (KNN): 6
Estimated Intrinsic Dimension (LPCA): 48
Estimated Intrinsic Dimension (MADA): 15.526202715518686
Estimated Intrinsic Dimension (MIND_ML): 10.0
Estimated Intrinsic Dimension (MLE): 11.85228294928126
Estimated Intrinsic Dimension (MOM): 4.662291966147815
Estimated Intrinsic Dimension (TLE): 11.681521116520777
Estimated Intrinsic Dimension (TWONN): 11.715313108714346

Here, we can see that several of the skdim methods (MLE, TLE, TwoNN, FisherS, and MADA, for example) give estimates reasonably close to our persistent homology estimate of about 13.06, while a few (DANCo, ESS, and lPCA) differ substantially. We should note at this point that the persistent homology method is stochastic, which is why we average over several runs, so you will not get exactly the same answer every time unless you use the same random seeds.
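If you want reproducible numbers, one option (a minimal sketch; the seed value is arbitrary) is to fix NumPy's global seed before calling the estimator, since the only source of randomness is the np.random.choice call inside sample_and_score:

np.random.seed(42)  # arbitrary seed, only needed for reproducibility
dimension = estimate_sequence_dimension(text)
print(f"Estimated dimension (seeded run): {dimension}")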

Conclusion and Some Use Cases

Understanding the intrinsic dimension of data, including protein sequences, can be invaluable for various analytical and computational tasks. It gives insights into the underlying structure of the data, which can then be used for optimization, visualization, and more. The method discussed here is especially intriguing as it combines advanced techniques from machine learning and topological data analysis (persistent homology).

What is this useful for? Well, the method was originally developed to detect AI-generated text in the article Intrinsic Dimension Estimation for Robust Detection of AI-Generated Texts, so in particular it might be used to detect AI-generated proteins. However, there are more interesting applications. For example, what if we train a protein language model on proteins with lower intrinsic dimension first, and then on proteins with progressively higher intrinsic dimension? This provides a form of curriculum learning, where the model learns from easier data first and then from progressively harder data, as sketched below.
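As a rough sketch of such a curriculum (the sequences below are toy placeholders, not curated proteins; in practice the corpus would come from a dataset like UniRef50):

# Toy placeholder sequences; replace with a real corpus in practice.
protein_corpus = [
    "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
    "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILAT",
    "MSDNGPQNQRNAPRITFGGPSDSTGSNQNGERSGARSKQRRPQGLPNNTASWFTALTQHGKE",
]

# Order sequences from lowest to highest estimated intrinsic dimension so the
# model trains on "simpler" proteins first.
curriculum = sorted(protein_corpus, key=estimate_sequence_dimension)

Note that this naive version reloads the tokenizer and model inside get_embeddings on every call, so for a real corpus you would want to load the model once and cache the per-sequence dimensions before sorting.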

Integrating intrinsic dimension into LLM training via curriculum learning could also prove useful, and the idea transfers readily to the NLP domain. This would likely increase the intrinsic dimension of AI-generated text, making it harder to distinguish from human-written text. The same may well hold for protein language models.

As a final note, we also recommend reading Bridging Information-Theoretic and Geometric Compression in Language Models, Topological Singularity Detection at Multiple Scales, The geometry of hidden representations of large transformer models, and The geometry of hidden representations of protein language models. The last two suggest that the rank for LoRA should be chosen per layer and should roughly match the intrinsic dimension of that layer. Also, recent research I have worked on suggests that LoRA and QLoRA can be effective regularization techniques, improving generalization to unseen data and dramatically reducing overfitting. If regularization is needed due to overfitting, choosing a LoRA rank lower than the intrinsic dimension of that layer is likely helpful.