yliu279 committed on
Commit d4fae64
1 Parent(s): 98dddbc

Update README.md

Files changed (1)
  1. README.md +57 -5
README.md CHANGED
@@ -3273,7 +3273,7 @@ language:
  license: cc-by-nc-4.0
  ---
 
- ## Salesforce/SFR-Embedding-Mistral
 
  **SFR-Embedding by Salesforce Research.**
 
@@ -3281,9 +3281,61 @@ The model is trained on top of [E5-mistral-7b-instruct](https://huggingface.co/i
 
  More technical details will be updated later.
 
- ### SFR-Embedding Team
- * Rui Meng
- * Ye Liu
  * Semih Yavuz
  * Yingbo Zhou
- * Caiming Xiong
 
  license: cc-by-nc-4.0
  ---
 
+ <h1 align="center">Salesforce/SFR-Embedding-Mistral</h1>
 
  **SFR-Embedding by Salesforce Research.**
 
  More technical details will be updated later.
 
+ ## How to run
+ The model can be used as follows:
+ ```python
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor
+ from transformers import AutoTokenizer, AutoModel
+ 
+ def last_token_pool(last_hidden_states: Tensor,
+                     attention_mask: Tensor) -> Tensor:
+     # With left padding, the last position always holds the final real token
+     left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
+     if left_padding:
+         return last_hidden_states[:, -1]
+     else:
+         sequence_lengths = attention_mask.sum(dim=1) - 1
+         batch_size = last_hidden_states.shape[0]
+         return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
+ 
+ def get_detailed_instruct(task_description: str, query: str) -> str:
+     return f'Instruct: {task_description}\nQuery: {query}'
+ 
+ # Each query must come with a one-sentence instruction that describes the task
+ task = 'Given a web search query, retrieve relevant passages that answer the query'
+ queries = [
+     get_detailed_instruct(task, 'How to bake a chocolate cake'),
+     get_detailed_instruct(task, 'Symptoms of the flu')
+ ]
+ # No need to add an instruction for the retrieval documents
+ passages = [
+     "To bake a delicious chocolate cake, you'll need the following ingredients: all-purpose flour, sugar, cocoa powder, baking powder, baking soda, salt, eggs, milk, vegetable oil, and vanilla extract. Start by preheating your oven to 350°F (175°C). In a mixing bowl, combine the dry ingredients (flour, sugar, cocoa powder, baking powder, baking soda, and salt). In a separate bowl, whisk together the wet ingredients (eggs, milk, vegetable oil, and vanilla extract). Gradually add the wet mixture to the dry ingredients, stirring until well combined. Pour the batter into a greased cake pan and bake for 30-35 minutes. Let it cool before frosting with your favorite chocolate frosting. Enjoy your homemade chocolate cake!",
+     "The flu, or influenza, is an illness caused by influenza viruses. Common symptoms of the flu include a high fever, chills, cough, sore throat, runny or stuffy nose, body aches, headache, fatigue, and sometimes nausea and vomiting. These symptoms can come on suddenly and are usually more severe than the common cold. It's important to get plenty of rest, stay hydrated, and consult a healthcare professional if you suspect you have the flu. In some cases, antiviral medications can help alleviate symptoms and reduce the duration of the illness."
+ ]
+ 
+ # load model and tokenizer
+ tokenizer = AutoTokenizer.from_pretrained('Salesforce/SFR-Embedding-Mistral')
+ model = AutoModel.from_pretrained('Salesforce/SFR-Embedding-Mistral')
+ tokenizer.add_eos_token = True
+ 
+ # get the embeddings (pad the batch and return tensors so the attention
+ # mask is available for pooling)
+ max_length = 4096
+ input_texts = queries + passages
+ batch_dict = tokenizer(input_texts, max_length=max_length, padding=True, truncation=True, return_tensors='pt')
+ with torch.no_grad():
+     outputs = model(**batch_dict)
+ embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
+ 
+ # normalize embeddings
+ embeddings = F.normalize(embeddings, p=2, dim=1)
+ scores = (embeddings[:2] @ embeddings[2:].T) * 100
+ print(scores.tolist())
+ ```
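Since downloading the 7B checkpoint is heavy, the last-token pooling logic above can be sanity-checked on toy tensors first (a minimal sketch; the shapes and values are illustrative only):

```python
import torch
from torch import Tensor

def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    # Same pooling as in the snippet above: pick the hidden state of the
    # last non-padding token in each sequence
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_size = last_hidden_states.shape[0]
    return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

# Two right-padded "sequences" with hidden size 4; real lengths 3 and 2
hidden = torch.arange(2 * 5 * 4, dtype=torch.float32).reshape(2, 5, 4)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 0, 0, 0]])
pooled = last_token_pool(hidden, mask)
print(torch.equal(pooled[0], hidden[0, 2]))  # → True (last real token at index 2)
print(torch.equal(pooled[1], hidden[1, 1]))  # → True (last real token at index 1)
```

Because `mask[:, -1]` is all zeros here, the function takes the right-padding branch and indexes each row at `attention_mask.sum(dim=1) - 1`; with left padding it would simply return the final position.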
+ 
+ ### SFR-Embedding Team
+ * Rui Meng*
+ * Ye Liu*
  * Semih Yavuz
  * Yingbo Zhou
+ * Caiming Xiong
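For reference, the scoring step in the snippet above (L2 normalization followed by a dot product, scaled by 100) is plain cosine similarity; this can be verified on toy vectors (values are illustrative only):

```python
import torch
import torch.nn.functional as F

# Toy embeddings: 2 "queries" followed by 2 "passages", dimension 3
embeddings = torch.tensor([[1.0, 0.0, 0.0],
                           [0.0, 2.0, 0.0],
                           [3.0, 0.0, 0.0],
                           [0.0, 0.0, 4.0]])
embeddings = F.normalize(embeddings, p=2, dim=1)   # unit-length rows
scores = (embeddings[:2] @ embeddings[2:].T) * 100  # cosine similarity * 100
print(scores.tolist())  # → [[100.0, 0.0], [0.0, 0.0]]
```

After normalization, magnitude differences (2.0, 3.0, 4.0) vanish: the first query and first passage point in the same direction and score 100, while orthogonal pairs score 0.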