bpucla committed
Commit 15a287e
1 Parent(s): 32d34de

Update README.md

Files changed (1):
  1. README.md +26 -11
README.md CHANGED
@@ -9,13 +9,13 @@ On all three widely-used instruct model benchmarks: **Alpaca-Eval-V2**, **MT-Ben
  and strong proprietary models (e.g., GPT-3.5-turbo-0613). The model is trained with open-sourced datasets without any additional human- or GPT4-labeling.
 
  ## Model Releases
- - SFT model
- - Reward model
- - RLHF model
+ - [SFT model](https://huggingface.co/Salesforce/SFR-SFT-LLaMA-3-8B-R)
+ - [Reward model](https://huggingface.co/Salesforce)
+ - [RLHF model](https://huggingface.co/Salesforce/SFR-Iterative-DPO-LLaMA-3-8B-R)
 
  ## Dataset Releases
- - Preference data mix
- - Prompt collection for RLHF training
+ - [Preference data mix]()
+ - [Prompt collection for RLHF training]()
 
  ## Training methods
  The key to our training is iterative RLHF.
@@ -53,20 +53,35 @@ The key to our training is iterative RLHF.
 
  ## Academic Benchmarks
 
- | **Model** | **Size** | **Method** | **GSM-8K** | **MMLU** | **HumanEval** | **TruthfulQA** | **ARC** | **MBPP** |
- |------------------------|----------|---------------|------------|----------|---------------|----------------|---------|----------|
- | LLaMA-3-8B-it | 8B | RS+DPO+PPO | 79.6 | 66.0 | 61.6 | 43.9 | 59.5 | 61.1 |
- | Ours (SFT baseline) | 8B | SFT | 76.7 | | 61.0 | | | 63.5 |
- | Ours (Offline baseline)| 8B | Vanilla DPO | 79.8 | | 63.4 | | | 60.3 |
- | Ours (Online RLHF) | 8B | Iterative DPO | 80.7 | 65.3 | 64.6 | 60.4 | 64.3 | 60.8 |
+ | **Model** | **Size** | **Method** | **GSM-8K** | **MMLU** | **HumanEval** | **TruthfulQA** | **ARC** | **MBPP** |
+ |----------------------------|----------|-----------------|------------|----------|---------------|----------------|---------|----------|
+ | LLaMA-3-8B-it | 8B | RS+DPO+PPO | 79.6 | 66.0 | 61.6 | 43.9 | 59.5 | 61.1 |
+ | Ours (SFT baseline) | 8B | SFT | 74.2 | 64.7 | 65.2 | 53.4 | 61.4 | 62.3 |
+ | Ours (DPO baseline) | 8B | Vanilla DPO | 79.8 | 64.5 | 63.4 | 61.8 | 65.2 | 60.3 |
+ | Ours (Iterative RLHF) | 8B | Iterative DPO | 80.7 | 65.3 | 64.6 | 60.4 | 64.3 | 60.8 |
 
 
  ## Usage
  ```python
  from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ device = "cuda"
+
  model = AutoModelForCausalLM.from_pretrained("Salesforce/SFR-Iterative-DPO-LLaMA-3-8B-R")
  tokenizer = AutoTokenizer.from_pretrained("Salesforce/SFR-Iterative-DPO-LLaMA-3-8B-R")
 
+ messages = [
+     {"role": "user", "content": "I'm trying to teach myself to have nicer handwriting. Can you help?"},
+ ]
+
+ model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt")
+
+ model_inputs = model_inputs.to(device)
+ model.to(device)
+
+ output_tokens = model.generate(model_inputs, max_new_tokens=1024, do_sample=True)
+ model_outputs = tokenizer.batch_decode(output_tokens)
+ print(model_outputs[0])
  ```
 
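The usage snippet added in this commit decodes the full sequence, so the printed text includes the prompt and the chat-template special tokens. A common variant, assuming the tokenizer ships a standard chat template, passes `add_generation_prompt=True` so generation starts from the assistant turn and decodes only the newly generated tokens; this is a sketch, not code from the repository, and the variable names are illustrative.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Salesforce/SFR-Iterative-DPO-LLaMA-3-8B-R"
model = AutoModelForCausalLM.from_pretrained(model_id).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)

messages = [
    {"role": "user", "content": "I'm trying to teach myself to have nicer handwriting. Can you help?"},
]

# add_generation_prompt=True appends the assistant-turn header so the model
# answers the user message instead of continuing it.
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to("cuda")

output_tokens = model.generate(input_ids, max_new_tokens=1024, do_sample=True)

# Decode only the tokens produced after the prompt, dropping special tokens.
reply = tokenizer.decode(output_tokens[0, input_ids.shape[-1]:], skip_special_tokens=True)
print(reply)
```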
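The Training methods section touched by this commit credits iterative RLHF, and the benchmark table labels the released model's method as Iterative DPO. As a minimal reference for what each DPO update in such a pipeline optimizes, the PyTorch sketch below implements the standard per-pair DPO objective, -log sigmoid(beta * ((log pi(y_w) - log pi_ref(y_w)) - (log pi(y_l) - log pi_ref(y_l)))); the function name, beta value, and toy log-probabilities are illustrative assumptions, not code from this repository.

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """Pairwise DPO loss from summed response log-probabilities.

    Each argument is a tensor holding the total log-probability of a full
    response under the trained policy or the frozen reference model.
    """
    chosen_margin = policy_chosen_logps - ref_chosen_logps
    rejected_margin = policy_rejected_logps - ref_rejected_logps
    return -F.logsigmoid(beta * (chosen_margin - rejected_margin)).mean()

# Toy values only, to show the call shape.
loss = dpo_loss(torch.tensor([-42.0]), torch.tensor([-55.0]),
                torch.tensor([-44.0]), torch.tensor([-54.0]))
print(loss.item())
```

In an iterative setup, fresh preference pairs are typically produced each round by sampling responses from the current policy and ranking them with the released reward model, after which this per-pair objective is optimized again; the sketch shows only the objective itself, not the full loop.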