yliu279 commited on
Commit
2e059fe
1 Parent(s): 4a720d1

Add first-party Sentence Transformers support + README snippet (#1)

Browse files

- Add first-party Sentence Transformers support + snippet (0737ee5e0072b3e79863d63f5ab2b4bcb573b442)

1_Pooling/config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "word_embedding_dimension": 4096,
3
+ "pooling_mode_cls_token": false,
4
+ "pooling_mode_mean_tokens": false,
5
+ "pooling_mode_max_tokens": false,
6
+ "pooling_mode_mean_sqrt_len_tokens": false,
7
+ "pooling_mode_weightedmean_tokens": false,
8
+ "pooling_mode_lasttoken": true
9
+ }
README.md CHANGED
@@ -3287,12 +3287,15 @@ This project is for research purposes only. Third-party datasets may be subject
3287
  More technical details will be updated later.
3288
 
3289
  ## How to run
 
 
3290
  The models can be used as follows:
3291
  ```python
3292
  import torch
3293
  import torch.nn.functional as F
3294
  from torch import Tensor
3295
  from transformers import AutoTokenizer, AutoModel
 
3296
  def last_token_pool(last_hidden_states: Tensor,
3297
  attention_mask: Tensor) -> Tensor:
3298
  left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
@@ -3321,7 +3324,6 @@ passages = [
3321
  # load model and tokenizer
3322
  tokenizer = AutoTokenizer.from_pretrained('Salesforce/SFR-Embedding-Mistral')
3323
  model = AutoModel.from_pretrained('Salesforce/SFR-Embedding-Mistral')
3324
- tokenizer.add_eos_token = True
3325
 
3326
  # get the embeddings
3327
  max_length = 4096
@@ -3334,6 +3336,35 @@ embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_ma
3334
  embeddings = F.normalize(embeddings, p=2, dim=1)
3335
  scores = (embeddings[:2] @ embeddings[2:].T) * 100
3336
  print(scores.tolist())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3337
  ```
3338
 
3339
  Code for MTEB evaluation will be added soon.
 
3287
  More technical details will be updated later.
3288
 
3289
  ## How to run
3290
+
3291
+ ### Transformers
3292
  The models can be used as follows:
3293
  ```python
3294
  import torch
3295
  import torch.nn.functional as F
3296
  from torch import Tensor
3297
  from transformers import AutoTokenizer, AutoModel
3298
+
3299
  def last_token_pool(last_hidden_states: Tensor,
3300
  attention_mask: Tensor) -> Tensor:
3301
  left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
 
3324
  # load model and tokenizer
3325
  tokenizer = AutoTokenizer.from_pretrained('Salesforce/SFR-Embedding-Mistral')
3326
  model = AutoModel.from_pretrained('Salesforce/SFR-Embedding-Mistral')
 
3327
 
3328
  # get the embeddings
3329
  max_length = 4096
 
3336
  embeddings = F.normalize(embeddings, p=2, dim=1)
3337
  scores = (embeddings[:2] @ embeddings[2:].T) * 100
3338
  print(scores.tolist())
3339
+ # [[86.7153549194336, 36.64569091796875], [35.00493621826172, 82.0738525390625]]
3340
+ ```
3341
+
3342
+ ### Sentence Transformers
3343
+ ```python
3344
+
3345
+ from sentence_transformers import SentenceTransformer, util
3346
+
3347
+ model = SentenceTransformer("Salesforce/SFR-Embedding-Mistral")
3348
+
3349
+ def get_detailed_instruct(task_description: str, query: str) -> str:
3350
+ return f'Instruct: {task_description}\nQuery: {query}'
3351
+
3352
+ # Each query must come with a one-sentence instruction that describes the task
3353
+ task = 'Given a web search query, retrieve relevant passages that answer the query'
3354
+ queries = [
3355
+ get_detailed_instruct(task, 'How to bake a chocolate cake'),
3356
+ get_detailed_instruct(task, 'Symptoms of the flu')
3357
+ ]
3358
+ # No need to add instruction for retrieval documents
3359
+ passages = [
3360
+ "To bake a delicious chocolate cake, you'll need the following ingredients: all-purpose flour, sugar, cocoa powder, baking powder, baking soda, salt, eggs, milk, vegetable oil, and vanilla extract. Start by preheating your oven to 350°F (175°C). In a mixing bowl, combine the dry ingredients (flour, sugar, cocoa powder, baking powder, baking soda, and salt). In a separate bowl, whisk together the wet ingredients (eggs, milk, vegetable oil, and vanilla extract). Gradually add the wet mixture to the dry ingredients, stirring until well combined. Pour the batter into a greased cake pan and bake for 30-35 minutes. Let it cool before frosting with your favorite chocolate frosting. Enjoy your homemade chocolate cake!",
3361
+ "The flu, or influenza, is an illness caused by influenza viruses. Common symptoms of the flu include a high fever, chills, cough, sore throat, runny or stuffy nose, body aches, headache, fatigue, and sometimes nausea and vomiting. These symptoms can come on suddenly and are usually more severe than the common cold. It's important to get plenty of rest, stay hydrated, and consult a healthcare professional if you suspect you have the flu. In some cases, antiviral medications can help alleviate symptoms and reduce the duration of the illness."
3362
+ ]
3363
+
3364
+ embeddings = model.encode(queries + passages)
3365
+ scores = util.cos_sim(embeddings[:2], embeddings[2:]) * 100
3366
+ print(scores.tolist())
3367
+ # [[86.71537780761719, 36.645721435546875], [35.00497055053711, 82.07388305664062]]
3368
  ```
3369
 
3370
  Code for MTEB evaluation will be added soon.
config_sentence_transformers.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "__version__": {
3
+ "sentence_transformers": "2.2.2",
4
+ "transformers": "4.37.2",
5
+ "pytorch": "2.1.0+cu121"
6
+ }
7
+ }
modules.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "idx": 0,
4
+ "name": "0",
5
+ "path": "",
6
+ "type": "sentence_transformers.models.Transformer"
7
+ },
8
+ {
9
+ "idx": 1,
10
+ "name": "1",
11
+ "path": "1_Pooling",
12
+ "type": "sentence_transformers.models.Pooling"
13
+ }
14
+ ]
sentence_bert_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "max_seq_length": 4096,
3
+ "do_lower_case": false
4
+ }
tokenizer_config.json CHANGED
@@ -1,4 +1,5 @@
1
  {
 
2
  "added_tokens_decoder": {
3
  "0": {
4
  "content": "<unk>",
 
1
  {
2
+ "add_eos_token": true,
3
  "added_tokens_decoder": {
4
  "0": {
5
  "content": "<unk>",