Adding a <CLS> token for classification finetuning

#69
by ricomnl - opened

BertModel uses the first token in the sequence for classification, and this is usually a special token ([CLS]); see here: https://huggingface.co/learn/nlp-course/chapter7/3?fw=tf#fine-tuning-distilbert-with-the-trainer-api
As far as I can tell, the dataset for classification finetuning does not add a special token to the beginning of the sequence, so the model uses the highest-ranked gene in that position instead. This might be suboptimal in practice.
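
To make the concern concrete, here is a minimal sketch of first-token pooling applied to a rank-encoded cell with no special token prepended. The gene names and hidden-state values are made up. The point is only that position 0 holds the top-ranked gene, so that gene's state becomes the sequence representation. (BertModel's pooler also passes this vector through a dense layer and tanh; that step is omitted here.)

```python
from typing import List

Vector = List[float]


def first_token_pool(hidden_states: List[Vector]) -> Vector:
    """Return the hidden state at position 0 of one sequence (first-token pooling)."""
    if not hidden_states:
        raise ValueError("sequence is empty")
    return list(hidden_states[0])


# Toy rank-encoded cell: gene tokens in descending rank, no special token prepended.
gene_tokens = ["GENE_A", "GENE_B", "GENE_C"]  # hypothetical gene tokens
hidden = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # made-up hidden states, one per position

pooled = first_token_pool(hidden)
print(gene_tokens[0], pooled)  # GENE_A [1.0, 0.0]
```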

Thank you for your question. In NLP, examples (sentences) are usually presented to the model contiguously, with a separator token between them. A CLS token is often added as a consistent first token of each input so that the sequence embedding can be drawn from that token rather than by mean pooling the token embeddings. NLP training can then use the model's full maximum input size for every example by taking a large block of text and breaking it into, for example, 512-token chunks, with sentence fragments wherever a chunk boundary cuts a sentence off. Sentences are usually much shorter than 512 tokens, so each input still contains plenty of complete sentences.
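
A minimal sketch of the NLP packing scheme described above, assuming sentences are already tokenized into integer ids. The ids 101 and 102 are toy values for the CLS and separator tokens, and the helper name is hypothetical. Sentences are joined with a separator, the stream is cut into fixed-size blocks, and each block starts with CLS; sentences that straddle a boundary end up as fragments.

```python
from typing import List


def pack_into_blocks(sentences: List[List[int]], cls_id: int, sep_id: int,
                     block_size: int) -> List[List[int]]:
    """Concatenate tokenized sentences with separators and cut into fixed-size blocks.

    Position 0 of every block is reserved for the CLS token. A trailing partial
    block is dropped here, which is one common choice.
    """
    stream: List[int] = []
    for sent in sentences:
        stream.extend(sent)
        stream.append(sep_id)
    body = block_size - 1  # room left after the CLS token
    n_full = len(stream) // body
    return [[cls_id] + stream[i * body:(i + 1) * body] for i in range(n_full)]


sentences = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]  # toy token ids
blocks = pack_into_blocks(sentences, cls_id=101, sep_id=102, block_size=5)
print(blocks)  # [[101, 1, 2, 3, 102], [101, 4, 5, 102, 6], [101, 7, 8, 9, 102]]
```

The second block ends with the start of the third sentence and the third block begins mid-sentence, which is exactly the fragmenting that is harmless for text but would distort a cell.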

For our application, the input for each cell is much larger (generally more genes are detected in each cell than there are words in a sentence), so usually only one cell will fit within the input. Additionally, presenting the model with fragments of cells at the edges of the input would distort the meaning of those cells, especially because we use a rank value encoding in which genes distant in rank provide context that is just as informative as genes close by (as opposed to sentences, where words near a given word may be more informative than distant ones, though distant words can of course also be informative).
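
A simplified sketch of rank value encoding as described above: a cell's detected genes are ordered from highest to lowest expression and truncated to the model's input size, so each cell becomes exactly one input. It assumes the expression values are already normalized (the normalization scheme is not specified in this thread), treats zero values as undetected, and uses hypothetical gene names.

```python
from typing import Dict, List


def rank_value_encode(expression: Dict[str, float], max_len: int) -> List[str]:
    """Order detected genes by expression, highest first, and truncate to max_len.

    Assumes `expression` is already normalized. Ties are broken by gene name so
    the output is deterministic.
    """
    detected = [(gene, value) for gene, value in expression.items() if value > 0]
    detected.sort(key=lambda gv: (-gv[1], gv[0]))
    return [gene for gene, _ in detected[:max_len]]


cell = {"GENE_A": 0.2, "GENE_B": 3.1, "GENE_C": 0.0, "GENE_D": 1.4}
print(rank_value_encode(cell, max_len=3))  # ['GENE_B', 'GENE_D', 'GENE_A']
```

Because one call yields one input per cell, there is no step that splits a cell across inputs or splices part of another cell in; either would change the relative ranks the model sees.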

Because we present only one cell at a time to the model, there is less need for a special token to distinguish the given example. We represent the cell by mean pooling its gene embeddings rather than by a special token embedding. We also use dynamic padding and length-grouped training to speed up training, given that we aren’t filling the input with fragments of cells the way NLP fills it with fragments of sentences.
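
A minimal, library-free sketch of the three pieces mentioned above: grouping cells of similar length into batches, padding each batch only to its own longest cell (dynamic padding), and mean pooling gene embeddings while ignoring padded positions. Function names and toy values are hypothetical. In the Hugging Face Trainer, comparable behavior comes from a padding data collator and the `group_by_length` training argument; the plain sort used here is a simplified stand-in, and a real training sampler would typically also shuffle.

```python
from typing import List, Tuple

Vector = List[float]


def length_grouped_batches(seqs: List[List[int]], batch_size: int) -> List[List[int]]:
    """Return batches of sequence indices, grouping sequences of similar length."""
    order = sorted(range(len(seqs)), key=lambda i: len(seqs[i]))
    return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]


def pad_batch(seqs: List[List[int]], pad_id: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Dynamic padding: pad only up to the longest sequence in this batch."""
    max_len = max(len(s) for s in seqs)
    input_ids = [s + [pad_id] * (max_len - len(s)) for s in seqs]
    attention_mask = [[1] * len(s) + [0] * (max_len - len(s)) for s in seqs]
    return input_ids, attention_mask


def masked_mean_pool(hidden: List[List[Vector]], mask: List[List[int]]) -> List[Vector]:
    """Average each cell's gene embeddings, skipping padded positions."""
    pooled = []
    for states, cell_mask in zip(hidden, mask):
        kept = [h for h, keep in zip(states, cell_mask) if keep]
        if not kept:
            raise ValueError("cell has no unpadded positions")
        dim = len(kept[0])
        pooled.append([sum(h[d] for h in kept) / len(kept) for d in range(dim)])
    return pooled


cells = [[5, 6, 7], [8, 9], [10, 11, 12, 13]]  # toy token ids, one list per cell
batches = length_grouped_batches(cells, batch_size=2)  # [[1, 0], [2]]
ids, mask = pad_batch([cells[i] for i in batches[0]], pad_id=0)
# ids  == [[8, 9, 0], [5, 6, 7]]
# mask == [[1, 1, 0], [1, 1, 1]]
hidden = [  # made-up per-position embeddings for the padded batch
    [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],  # last row is padding and is ignored
    [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]],
]
print(masked_mean_pool(hidden, mask))  # [[2.0, 3.0], [1.0, 1.0]]
```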

While a CLS token could have been used as the cell embedding for cell classification, summarizing the embedding in a single token presents a problem for the in silico perturbation strategy. In that application, we derive the cell embedding shift in response to a perturbation by comparing the embeddings of all genes except the perturbed gene, so that we quantify the perturbation’s effect on the context. A CLS token could not accomplish this, because the perturbed gene's contribution could not be separated from that of the other genes within that single token.
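
A minimal sketch of the context comparison described above, under several assumptions not stated in this thread: per-gene embeddings are available as dictionaries keyed by gene for the original and perturbed cell, the perturbation is modeled as deleting the gene, and cosine similarity of mean-pooled context embeddings is the comparison metric. All names and values are hypothetical.

```python
import math
from typing import Dict, List

Vector = List[float]


def mean_vector(vectors: List[Vector]) -> Vector:
    """Element-wise mean of equal-length vectors."""
    dim = len(vectors[0])
    return [sum(v[d] for v in vectors) / len(vectors) for d in range(dim)]


def cosine_similarity(a: Vector, b: Vector) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def context_similarity(original: Dict[str, Vector], perturbed: Dict[str, Vector],
                       gene: str) -> float:
    """Compare the cell's context before and after perturbing `gene`.

    Both sides are mean-pooled over the same genes with the perturbed gene
    excluded, so the score reflects only the remaining genes (1.0 = no shift).
    """
    context_genes = sorted((original.keys() & perturbed.keys()) - {gene})
    if not context_genes:
        raise ValueError("no context genes left to compare")
    before = mean_vector([original[g] for g in context_genes])
    after = mean_vector([perturbed[g] for g in context_genes])
    return cosine_similarity(before, after)


# Toy gene embeddings (made up). GENE_A is deleted in the perturbed cell.
original = {"GENE_A": [1.0, 0.0], "GENE_B": [0.0, 1.0], "GENE_C": [1.0, 1.0]}
perturbed = {"GENE_B": [0.0, 1.0], "GENE_C": [1.0, 0.0]}
print(context_similarity(original, perturbed, "GENE_A"))  # about 0.9487
```

A single pooled CLS vector offers no equivalent of the `- {gene}` step: every gene, including the perturbed one, has already been mixed into it.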

That said, you can always train the model with a CLS token added to each cell; the model will learn the meaning of that special token, and you can then use it as a cell embedding.
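
A minimal sketch of that option at the data level: prepend a dedicated token id to every cell before training. `CLS_ID` and the helper name are hypothetical. The id must not collide with any gene token, and the model's embedding matrix needs a row for it (Hugging Face models provide `resize_token_embeddings` for this). After training, the hidden state at position 0 would serve as the cell embedding.

```python
from typing import List


def add_cls_token(gene_token_ids: List[int], cls_id: int, max_len: int) -> List[int]:
    """Prepend a CLS token to one cell's rank-ordered gene tokens.

    Genes are truncated so the total length stays within max_len; because the
    list is already in rank order, the lowest-ranked genes are the ones dropped.
    """
    if max_len < 2:
        raise ValueError("max_len must leave room for at least one gene")
    return [cls_id] + gene_token_ids[: max_len - 1]


CLS_ID = 25_000  # hypothetical: any id not already assigned to a gene
cell = [17, 4, 912, 33]  # toy rank-ordered gene token ids
print(add_cls_token(cell, CLS_ID, max_len=4))  # [25000, 17, 4, 912]
```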

ctheodoris changed discussion status to closed
