Nyströmformer: Approximating self-attention in linear time and memory via the Nyström method

Introduction

Transformers have exhibited remarkable performance on various Natural Language Processing and Computer Vision tasks. Their success can be attributed to the self-attention mechanism, which captures the pairwise interactions between all the tokens in an input. However, the standard self-attention mechanism has a time and memory complexity of \\(O(n^2)\\) (where \\(n\\) is the length of the input sequence), making it expensive to train on long input sequences.

The Nyströmformer is one of many efficient Transformer models that approximates standard self-attention with \\(O(n)\\) complexity. Nyströmformer exhibits competitive performance on various downstream NLP and CV tasks while improving upon the efficiency of standard self-attention. The aim of this blog post is to give readers an overview of the Nyström method and how it can be adapted to approximate self-attention.

Nyström method for matrix approximation

At the heart of Nyströmformer is the Nyström method for matrix approximation. It allows us to approximate a matrix by sampling some of its rows and columns. Let's consider a matrix \\(P^{n \times n}\\), which is expensive to compute in its entirety. So, instead, we approximate it using the Nyström method. We start by sampling \\(m\\) rows and columns from \\(P\\). We can then arrange the sampled rows and columns as follows:

Representing P as a block matrix

We now have four submatrices: \\(A_P\\), \\(B_P\\), \\(F_P\\), and \\(C_P\\), with sizes \\(m \times m\\), \\(m \times (n - m)\\), \\((n - m) \times m\\), and \\((n - m) \times (n - m)\\) respectively. The \\(m\\) sampled columns are contained in \\(A_P\\) and \\(F_P\\), whereas the \\(m\\) sampled rows are contained in \\(A_P\\) and \\(B_P\\). So, the entries of \\(A_P\\), \\(B_P\\), and \\(F_P\\) are known to us, and we will estimate \\(C_P\\). According to the Nyström method, \\(C_P\\) is given by:

$$C_P = F_P A_P^+ B_P$$

Here, \\(+\\) denotes the Moore-Penrose inverse (or pseudoinverse). Thus, the Nyström approximation of \\(P\\), \\(\hat{P}\\), can be written as:

Nyström approximation of P

As shown in the second line, \\(\hat{P}\\) can be expressed as a product of three matrices. The reason for doing so will become clear later.
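
As a quick numerical sketch (not from the paper, just to illustrate the idea), we can build the Nyström approximation of a small low-rank matrix in PyTorch and check that it reconstructs the original:

import torch

n, m = 8, 4  # illustrative sizes
P = torch.randn(n, m) @ torch.randn(m, n)  # a rank-m matrix, so the approximation is exact

# Blocks formed from the first m sampled rows and columns
A_P = P[:m, :m]   # m x m
B_P = P[:m, m:]   # m x (n - m)
F_P = P[m:, :m]   # (n - m) x m

# Nyström estimate of the unsampled block, C_P = F_P A_P^+ B_P
C_P = F_P @ torch.linalg.pinv(A_P) @ B_P

# Reassemble the approximation and compare against the original
P_hat = torch.cat([torch.cat([A_P, B_P], dim=1), torch.cat([F_P, C_P], dim=1)], dim=0)
print(torch.norm(P - P_hat))  # close to zero for a rank-m matrix

For a general full-rank matrix the reconstruction is only approximate, but the same three-matrix product is used.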

Can we approximate self-attention with the Nyström method?

Our goal is to ultimately approximate the softmax matrix in standard self-attention: \\(S = softmax \left(\frac{QK^T}{\sqrt{d}}\right)\\)

Here, \\(Q\\) and \\(K\\) denote the queries and keys respectively. Following the procedure discussed above, we would sample \\(m\\) rows and columns from \\(S\\), form four submatrices, and obtain \\(\hat{S}\\):

Nyström approximation of S

But what does it mean to sample a column from \\(S\\)? It means we select one element from each row. Recall how \\(S\\) is calculated: the final operation is a row-wise softmax. To find a single entry in a row, we must access all the other entries in that row (for the denominator of the softmax). So, sampling even one column requires us to know all the other columns in the matrix. Therefore, we cannot directly apply the Nyström method to approximate the softmax matrix.
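
To see why, write out a single entry of \\(S\\) (with \\(q_i\\) and \\(k_j\\) denoting the \\(i\\)-th query and \\(j\\)-th key):

$$S_{ij} = \frac{\exp\left(q_i k_j^T / \sqrt{d}\right)}{\sum_{l=1}^{n} \exp\left(q_i k_l^T / \sqrt{d}\right)}$$

The denominator runs over all \\(n\\) keys, so even a single column of \\(S\\) depends on the entire \\(QK^T\\) product.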

How can we adapt the Nyström method to approximate self-attention?

Instead of sampling from \\(S\\), the authors propose to sample landmarks (or Nyström points) from queries and keys. We denote the query landmarks and key landmarks as \\(\tilde{Q}\\) and \\(\tilde{K}\\) respectively. \\(\tilde{Q}\\) and \\(\tilde{K}\\) can be used to construct three matrices corresponding to those in the Nyström approximation of \\(S\\). We define the following matrices:

$$\tilde{F} = softmax\left(\frac{Q\tilde{K}^T}{\sqrt{d}}\right) \hspace{40pt} \tilde{A} = softmax\left(\frac{\tilde{Q}\tilde{K}^T}{\sqrt{d}}\right)^+ \hspace{40pt} \tilde{B} = softmax\left(\frac{\tilde{Q}K^T}{\sqrt{d}}\right)$$

The sizes of \\(\tilde{F}\\), \\(\tilde{A}\\), and \\(\tilde{B}\\) are \\(n \times m\\), \\(m \times m\\), and \\(m \times n\\) respectively. We replace the three matrices in the Nyström approximation of \\(S\\) with the new matrices we have defined to obtain an alternative Nyström approximation:

$$\begin{aligned}\hat{S} &= \tilde{F} \tilde{A} \tilde{B} \\ &= softmax\left(\frac{Q\tilde{K}^T}{\sqrt{d}}\right) \ softmax\left(\frac{\tilde{Q}\tilde{K}^T}{\sqrt{d}}\right)^+ softmax\left(\frac{\tilde{Q}K^T}{\sqrt{d}}\right) \end{aligned}$$

This is the Nyström approximation of the softmax matrix in the self-attention mechanism. We multiply this matrix with the values (\\(V\\)) to obtain a linear approximation of self-attention. Note that we never calculate the product \\(QK^T\\), avoiding the \\(O(n^2)\\) complexity.
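
To make the computation concrete, here is a minimal sketch of this factorization in PyTorch (ignoring batching, multiple heads, and masking; q_tilde and k_tilde are assumed to be precomputed landmark matrices, and torch.linalg.pinv stands in for the iterative pseudoinverse used in the actual implementation):

import torch

def approximate_attention(q, k, v, q_tilde, k_tilde):
    # q, k, v: (n, d); q_tilde, k_tilde: (m, d) landmark matrices
    scale = q.size(-1) ** 0.5
    f = torch.softmax(q @ k_tilde.transpose(-1, -2) / scale, dim=-1)  # \tilde{F}: n x m
    a = torch.linalg.pinv(torch.softmax(q_tilde @ k_tilde.transpose(-1, -2) / scale, dim=-1))  # \tilde{A}: m x m
    b = torch.softmax(q_tilde @ k.transpose(-1, -2) / scale, dim=-1)  # \tilde{B}: m x n
    # Multiplying right to left keeps every intermediate at most n x m,
    # so no n x n matrix is ever materialized
    return f @ (a @ (b @ v))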

How do we select landmarks?

Instead of sampling \\(m\\) rows from \\(Q\\) and \\(K\\), the authors propose to construct \\(\tilde{Q}\\) and \\(\tilde{K}\\) using segment means. In this procedure, the \\(n\\) tokens are grouped into \\(m\\) segments, and the mean of each segment is computed. Ideally, \\(m\\) is much smaller than \\(n\\). According to experiments from the paper, selecting just \\(32\\) or \\(64\\) landmarks produces competitive performance compared to standard self-attention and other efficient attention mechanisms, even for long sequence lengths (\\(n = 4096\\) or \\(8192\\)).

The overall algorithm is summarised by the following figure from the paper:

Efficient self-attention with the Nyström method

The three orange matrices above correspond to the three matrices we constructed using the key and query landmarks. Also, notice that there is a DConv box. This corresponds to a skip connection added to the values using a 1D depthwise convolution.
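
As a rough sketch of what such a layer can look like (the channel count and kernel size below are hypothetical, not necessarily the values used in the library), a depthwise 1D convolution can be built with nn.Conv1d by setting groups equal to the number of channels, and its output added to the attention output as a skip connection from the values:

import torch
from torch import nn

head_size, kernel_size = 64, 33  # hypothetical sizes
dconv = nn.Conv1d(
    in_channels=head_size,
    out_channels=head_size,
    kernel_size=kernel_size,
    padding=kernel_size // 2,
    groups=head_size,  # depthwise: one filter per channel
    bias=False,
)

v = torch.randn(1, 4096, head_size)                 # values: (batch, seq_len, head_size)
attention_output = torch.randn(1, 4096, head_size)  # stand-in for \hat{S} V
v_conv = dconv(v.transpose(1, 2)).transpose(1, 2)   # convolve along the sequence dimension
output = attention_output + v_conv                  # skip connection from the values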

How is Nyströmformer implemented?

The original implementation of Nyströmformer can be found here and the HuggingFace implementation can be found here. Let's take a look at a few lines of code (with some comments added) from the HuggingFace implementation. Note that some details such as normalization, attention masking, and depthwise convolution are omitted for simplicity.


key_layer = self.transpose_for_scores(self.key(hidden_states)) # K
value_layer = self.transpose_for_scores(self.value(hidden_states)) # V
query_layer = self.transpose_for_scores(mixed_query_layer) # Q

q_landmarks = query_layer.reshape(
    -1,
    self.num_attention_heads,
    self.num_landmarks,
    self.seq_len // self.num_landmarks,
    self.attention_head_size,
).mean(dim=-2) # \tilde{Q}

k_landmarks = key_layer.reshape(
    -1,
    self.num_attention_heads,
    self.num_landmarks,
    self.seq_len // self.num_landmarks,
    self.attention_head_size,
).mean(dim=-2) # \tilde{K}

kernel_1 = torch.nn.functional.softmax(torch.matmul(query_layer, k_landmarks.transpose(-1, -2)), dim=-1) # \tilde{F}
kernel_2 = torch.nn.functional.softmax(torch.matmul(q_landmarks, k_landmarks.transpose(-1, -2)), dim=-1) # \tilde{A} before pseudo-inverse

attention_scores = torch.matmul(q_landmarks, key_layer.transpose(-1, -2)) # \tilde{B} before softmax

kernel_3 = torch.nn.functional.softmax(attention_scores, dim=-1) # \tilde{B}
attention_probs = torch.matmul(kernel_1, self.iterative_inv(kernel_2)) # \tilde{F} * \tilde{A}
new_value_layer = torch.matmul(kernel_3, value_layer) # \tilde{B} * V
context_layer = torch.matmul(attention_probs, new_value_layer) # \tilde{F} * \tilde{A} * \tilde{B} * V
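
One detail worth highlighting is self.iterative_inv: instead of an exact pseudoinverse, \\(\tilde{A}^+\\) is approximated with a small, fixed number of matrix-multiplication updates. A sketch of such an iteration, following the update rule described in the paper (the initialization and iteration count here are illustrative and may differ from the library's exact implementation), looks like this:

import torch

def iterative_pinv(mat, n_iter=6):
    # Approximate the Moore-Penrose pseudoinverse of a square matrix
    # using only matrix multiplications
    identity = torch.eye(mat.size(-1), device=mat.device)
    # Initial guess scaled so that the iteration converges
    z = mat.transpose(-1, -2) / (
        torch.max(torch.abs(mat).sum(dim=-2)) * torch.max(torch.abs(mat).sum(dim=-1))
    )
    for _ in range(n_iter):
        mz = torch.matmul(mat, z)
        z = 0.25 * torch.matmul(
            z, 13 * identity - torch.matmul(mz, 15 * identity - torch.matmul(mz, 7 * identity - mz))
        )
    return z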

Using Nyströmformer with HuggingFace

Nyströmformer for Masked Language Modeling (MLM) is available on HuggingFace. Currently, there are 4 checkpoints, corresponding to various sequence lengths: nystromformer-512, nystromformer-1024, nystromformer-2048, and nystromformer-4096. The number of landmarks, \\(m\\), can be controlled using the num_landmarks parameter in the NystromformerConfig. Let's take a look at a minimal example of Nyströmformer for MLM:

from transformers import AutoTokenizer, NystromformerForMaskedLM
import torch

tokenizer = AutoTokenizer.from_pretrained("uw-madison/nystromformer-512")
model = NystromformerForMaskedLM.from_pretrained("uw-madison/nystromformer-512")

inputs = tokenizer("Paris is the [MASK] of France.", return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

# retrieve index of [MASK]
mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]

predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
tokenizer.decode(predicted_token_id)
Output:
----------------------------------------------------------------------------------------------------
capital

Alternatively, we can use the pipeline API (which handles all the complexity for us):

from transformers import pipeline
unmasker = pipeline('fill-mask', model='uw-madison/nystromformer-512')
unmasker("Paris is the [MASK] of France.")
Output:
----------------------------------------------------------------------------------------------------
[{'score': 0.829957902431488,
  'token': 1030,
  'token_str': 'capital',
  'sequence': 'paris is the capital of france.'},
{'score': 0.022157637402415276,
  'token': 16081,
  'token_str': 'birthplace',
  'sequence': 'paris is the birthplace of france.'},
{'score': 0.01904447190463543,
  'token': 197,
  'token_str': 'name',
  'sequence': 'paris is the name of france.'},
{'score': 0.017583081498742104,
  'token': 1107,
  'token_str': 'kingdom',
  'sequence': 'paris is the kingdom of france.'},
{'score': 0.005948934704065323,
  'token': 148,
  'token_str': 'city',
  'sequence': 'paris is the city of france.'}]

Conclusion

Nyströmformer offers an efficient approximation to the standard self-attention mechanism, while outperforming other linear self-attention schemes. In this blog post, we gave a high-level overview of the Nyström method and how it can be leveraged to approximate self-attention. Readers interested in deploying or fine-tuning Nyströmformer for downstream tasks can find the HuggingFace documentation here.