Linear Probing?

#43
by nalzok - opened

I want to use your model to do binary text classification (given a tweet, determine whether it's hate speech). I think the easiest approach is to do linear probing (training a logistic regression model on top of your embeddings), but I have some questions about the details:

  1. Should I normalize the embeddings with F.normalize(embeddings, p=2, dim=1)?
  2. What context length should I use? The max sequence input length is 32k, but you mentioned in Section 5.1 of the paper that only the first 512 tokens are used during both training and evaluation, so I'm a little confused. In particular, does that mean the remaining (32k - 512) attention positions are essentially "untrained"?
  3. Which objective do you recommend minimizing: the logistic loss, or a contrastive loss (i.e., cosine similarity between the embedding and a weight vector)?

Hi @nalzok, thanks for raising these interesting questions. Here are my answers to your three questions.

  1. Normalization: In general, L2 normalization does not affect accuracy much. In our work, we apply it by default. If you can share your experience with the effect of normalization in the future, we would appreciate it.
  2. Context length: You can use any context length within our model's 32k limit. At the moment, our model shows the best long-document retrieval accuracy among non-ICL (in-context learning) models on AIR-Bench, as shown in the screenshot attached below. This is also highlighted on the page of another newly released ICL model.
  3. Learning objective: If you want to put a logistic regression model on top of the embedding model, the logistic loss is likely more appropriate than the contrastive one; a minimal sketch of this setup is included below.

[Screenshot: AIR-Bench long-document retrieval results]
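Here is a minimal sketch of the linear-probing setup discussed above. It assumes the tweet embeddings have already been computed with NV-Embed and stored as a torch tensor; `train_emb` and `train_labels` are placeholder names, and the random tensors are stand-ins for your real data. It shows the optional L2 normalization (question 1) and a linear probe trained with the logistic loss (question 3).

```python
import torch
import torch.nn.functional as F
from torch import nn

# Placeholder data: replace with embeddings produced by NV-Embed.
# train_emb: (N, D) float tensor, train_labels: (N,) float tensor of 0/1.
D = 4096
train_emb = torch.randn(128, D)
train_labels = torch.randint(0, 2, (128,)).float()

# Optional L2 normalization (used by default in the paper).
train_emb = F.normalize(train_emb, p=2, dim=1)

# Linear probe trained with the logistic loss (BCE on logits).
probe = nn.Linear(D, 1)
optimizer = torch.optim.AdamW(probe.parameters(), lr=1e-3)
loss_fn = nn.BCEWithLogitsLoss()

for epoch in range(20):
    optimizer.zero_grad()
    logits = probe(train_emb).squeeze(-1)
    loss = loss_fn(logits, train_labels)
    loss.backward()
    optimizer.step()

# Predicted probability that each tweet is hate speech.
with torch.no_grad():
    probs = torch.sigmoid(probe(train_emb).squeeze(-1))
```

Equivalently, you could fit scikit-learn's `LogisticRegression` on the (normalized) embeddings; the PyTorch version above just makes the logistic loss explicit.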

Hi @nada5, thanks for the helpful comments! Regarding the second point, can you explain why your model works at long context lengths? My understanding is that the base model is trained at 32k, but your adapter is only trained at 512, so I'm a little surprised to see it perform so well.


Hi @nalzok, thanks for the question. Yes, we trained NV-Embed with a 512-token limit. However, most of our context passages are about 75-200 tokens long, which is sufficient to capture the embedding information for contrastive learning. We believe our multi-task training helps the model generalize better across different tasks, so it also performs well on long-context embedding tasks.
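As a quick sanity check for the classification use case, you can tokenize a sample of tweets and confirm they fall well within the 512-token training budget. The sketch below assumes the standard Hugging Face loading pattern; the model id is an assumption, so adjust it to the exact NV-Embed checkpoint you are using.

```python
from transformers import AutoTokenizer

# Assumed checkpoint id; replace with the NV-Embed model you actually use.
tokenizer = AutoTokenizer.from_pretrained("nvidia/NV-Embed-v1", trust_remote_code=True)

tweets = [
    "example tweet one ...",
    "example tweet two ...",
]
lengths = [len(tokenizer.encode(t)) for t in tweets]
print(max(lengths))  # typically far below 512 for tweet-length inputs
```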

Thanks for the explanation. That's helpful!

nalzok changed discussion status to closed
