Dongfu Jiang
---
license: mit
datasets:
- llm-blender/mix-instruct
metrics:
- BERTScore
- BLEURT
- BARTScore
- Pairwise Rank
tags:
- pair-ranker
- pair_ranker
- reward_model
- reward-model
- RLHF
---
PairRanker is the ranker component of LLM-Blender, built on deberta-v3-large. This is the ranker checkpoint used in the experiments of the LLM-Blender paper; it was trained on the [MixInstruct](https://huggingface.co/datasets/llm-blender/mix-instruct) dataset for 5 epochs.
- Github: [https://github.com/yuchenlin/LLM-Blender](https://github.com/yuchenlin/LLM-Blender)
- Paper: [https://arxiv.org/abs/2306.02561](https://arxiv.org/abs/2306.02561)
## Statistics
### Context length
| PairRanker type | Source max length | Candidate max length | Total max length |
|:-----------------:|:-----------------:|----------------------|------------------|
| [pair-ranker](https://huggingface.co/llm-blender/pair-ranker) (This model) | 128 | 128 | 384 |
| [pair-reward-model](https://huggingface.co/llm-blender/pair-reward-model/) | 1224 | 412 | 2048 |
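
In both rows the total budget equals the source budget plus twice the candidate budget (128 + 2 × 128 = 384; 1224 + 2 × 412 = 2048), which is consistent with a pairwise ranker that reads one source and two candidates in a single sequence. The sketch below illustrates that budget with plain token-ID lists. `build_pair_input` is a hypothetical helper, not part of `llm-blender`, and it deliberately ignores the special and separator tokens a real tokenizer would add.

```python
from typing import List


def build_pair_input(
    source: List[int],
    cand_a: List[int],
    cand_b: List[int],
    source_max: int = 128,
    candidate_max: int = 128,
) -> List[int]:
    """Hypothetical helper: truncate each part to its own budget, then concatenate.

    Real PairRanker inputs also contain special/separator tokens, which this
    sketch leaves out on purpose.
    """
    return source[:source_max] + cand_a[:candidate_max] + cand_b[:candidate_max]


# Both table rows satisfy: total = source budget + 2 * candidate budget.
for source_max, candidate_max, total_max in [(128, 128, 384), (1224, 412, 2048)]:
    assert source_max + 2 * candidate_max == total_max

# Over-long parts are cut to their individual budgets; short parts are kept whole.
ids = build_pair_input(list(range(300)), list(range(200)), list(range(50)))
print(len(ids))  # 128 + 128 + 50 = 306
```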
### MixInstruct Performance
| **Methods** | BERTScore | BARTScore | BLEURT | GPT-Rank | Beat Vicuna (%) | Beat Open Assistant (%) | Top-1 (%) | Top-2 (%) | Top-3 (%) |
|:-----------------:|:---------:|:---------:|:---------:|:--------:|:----------:|:----------:|:----------:|:----------:|:----------:|
| Open Assistant | **74.68** | -3.45 | **-0.39** | **3.90** | **62.78** | N/A | 17.35 | 35.67 | 51.98 |
| Vicuna | 69.60 | **-3.44** | -0.61 | 4.13 | N/A | **64.77** | **25.47** | **41.23** | **52.88** |
| Alpaca | 71.46 | -3.57 | -0.53 | 4.62 | 56.70 | 61.35 | 15.41 | 29.81 | 44.46 |
| Baize | 65.57 | -3.53 | -0.66 | 4.86 | 52.76 | 56.40 | 14.23 | 26.91 | 38.80 |
| moss | 64.85 | -3.65 | -0.73 | 5.09 | 51.62 | 51.79 | 15.93 | 27.52 | 38.27 |
| ChatGLM | 70.38 | -3.52 | -0.62 | 5.63 | 44.04 | 45.67 | 9.41 | 19.37 | 28.78 |
| Koala | 63.96 | -3.85 | -0.84 | 6.76 | 39.93 | 39.01 | 8.15 | 15.72 | 22.55 |
| Dolly v2 | 62.26 | -3.83 | -0.87 | 6.90 | 33.33 | 31.44 | 5.16 | 10.06 | 16.45 |
| Mosaic MPT | 63.21 | -3.72 | -0.82 | 7.19 | 30.87 | 30.16 | 5.39 | 10.61 | 16.24 |
| StableLM | 62.47 | -4.12 | -0.98 | 8.71 | 21.55 | 19.87 | 2.33 | 4.74 | 7.96 |
| Flan-T5 | 64.92 | -4.57 | -1.23 | 8.81 | 23.89 | 19.93 | 1.30 | 2.87 | 5.32 |
| Oracle(BERTScore) | **77.67** | -3.17 | -0.27 | 3.88 | 54.41 | 38.84 | 20.16 | 38.11 | 53.49 |
| Oracle(BLEURT) | 75.02 | -3.15 | **-0.15** | 3.77 | 55.61 | 45.80 | 21.48 | 39.84 | 55.36 |
| Oracle(BARTScore) | 73.23 | **-2.87** | -0.38 | 3.69 | 50.32 | 57.01 | 26.10 | 43.70 | 57.33 |
| Oracle(ChatGPT) | 70.32 | -3.33 | -0.51 | **1.00** | **100.00** | **100.00** | **100.00** | **100.00** | **100.00** |
| Random | 66.36 | -3.76 | -0.77 | 6.14 | 37.75 | 36.91 | 11.28 | 20.69 | 29.05 |
| MLM-Scoring | 64.77 | -4.03 | -0.88 | 7.00 | 33.87 | 30.39 | 7.29 | 14.09 | 21.46 |
| SimCLS | **73.14** | -3.22 | -0.38 | 3.50 | 52.11 | 49.93 | 26.72 | 46.24 | 60.72 |
| SummaReranker | 71.60 | -3.25 | -0.41 | 3.66 | **55.63** | 48.46 | 23.89 | 42.44 | 57.54 |
| [**PairRanker**](https://huggingface.co/llm-blender/pair-ranker) | 72.97 | **-3.14** | **-0.37** | **3.20** | 54.76 | **57.79** | **30.08** | **50.68** | **65.12** |
## Usage Example
Because PairRanker contains custom layers and special tokens, we recommend using it through our `llm-blender` Python package.
Loading it directly with the Hugging Face `from_pretrained()` API will raise errors.
- First, install `llm-blender`:
```bash
pip install git+https://github.com/yuchenlin/LLM-Blender.git
```
- Then load PairRanker with the following code:
```python
import llm_blender
# ranker config
ranker_config = llm_blender.RankerConfig()
ranker_config.ranker_type = "pairranker" # only supports pairranker now.
ranker_config.model_type = "deberta"
ranker_config.model_name = "microsoft/deberta-v3-large" # ranker backbone
ranker_config.load_checkpoint = "llm-blender/pair-ranker" # Hugging Face Hub model path, or <your local ranker checkpoint path>
ranker_config.cache_dir = "./hf_models" # Hugging Face model cache dir
ranker_config.source_maxlength = 128
ranker_config.candidate_maxlength = 128
ranker_config.n_tasks = 1 # number of signals used to train the ranker; this checkpoint was trained with BARTScore only, hence 1
# the fuser is not used here, but its config is still passed to Blender; load the fuser if you need it
fuser_config = llm_blender.GenFuserConfig()
# blender config
blender_config = llm_blender.BlenderConfig()
blender_config.device = "cuda" # device for the blender's ranker and fuser
blender = llm_blender.Blender(blender_config, ranker_config, fuser_config)
```
- Then rank candidates with `blender.rank`:
```python
inputs = ["input1", "input2"]
candidates_texts = [["candidate1 for input1", "candidate2 for input1"], ["candidate1 for input2", "candidate2 for input2"]]
ranks = blender.rank(inputs, candidates_texts, return_scores=False, batch_size=2)
# ranks is a nested list where ranks[i][j] is the rank of candidate j for input i
```
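
To keep only the best candidate(s) per input, the ranks can be turned into a selection. This is a minimal sketch, assuming that a lower rank means a better candidate (rank 1 is best) and that `ranks` has the nested shape described in the comment above. `top_k_from_ranks` is a hypothetical helper rather than part of the `llm-blender` API, and the ranks below are made-up stand-ins for real `blender.rank` output.

```python
from typing import List, Sequence


def top_k_from_ranks(
    ranks: Sequence[Sequence[int]],
    candidates: Sequence[Sequence[str]],
    k: int = 1,
) -> List[List[str]]:
    """Hypothetical helper: return the k best candidates per input (assumes rank 1 = best)."""
    selected = []
    for input_ranks, input_cands in zip(ranks, candidates):
        # Sort candidate indices by ascending rank, so the best candidate comes first.
        order = sorted(range(len(input_ranks)), key=lambda j: input_ranks[j])
        selected.append([input_cands[j] for j in order[:k]])
    return selected


# Made-up stand-in for the output of blender.rank(...).
example_ranks = [[2, 1], [1, 2]]
example_candidates = [
    ["candidate1 for input1", "candidate2 for input1"],
    ["candidate1 for input2", "candidate2 for input2"],
]
print(top_k_from_ranks(example_ranks, example_candidates, k=1))
# [['candidate2 for input1'], ['candidate1 for input2']]
```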
- Or use PairRanker to compare two candidates directly:
```python
candidates_A = [cands[0] for cands in candidates_texts]
candidates_B = [cands[1] for cands in candidates_texts]
comparison_results = blender.compare(inputs, candidates_A, candidates_B)
# comparison_results is a list of bools, where comparison_results[i] indicates whether candidates_A[i] is better than candidates_B[i] for inputs[i]
```
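
Pairwise decisions like these can be aggregated into a full ranking. The sketch below uses simple win counting over all pairs; this is an illustrative choice and not necessarily how `blender.rank` aggregates internally. `rank_by_wins` and the length-based comparator are hypothetical stand-ins for real PairRanker calls.

```python
from itertools import combinations
from typing import Callable, List, Sequence


def rank_by_wins(
    candidates: Sequence[str],
    is_better: Callable[[str, str], bool],
) -> List[int]:
    """Hypothetical helper: rank candidates by how many pairwise comparisons each wins.

    Returns ranks aligned with `candidates` (1 = most wins); ties keep input order.
    """
    wins = [0] * len(candidates)
    # Compare every unordered pair once and credit the winner.
    for i, j in combinations(range(len(candidates)), 2):
        if is_better(candidates[i], candidates[j]):
            wins[i] += 1
        else:
            wins[j] += 1
    # Stable sort by descending win count, then convert positions to ranks.
    order = sorted(range(len(candidates)), key=lambda idx: -wins[idx])
    ranks = [0] * len(candidates)
    for position, idx in enumerate(order, start=1):
        ranks[idx] = position
    return ranks


# Toy comparator standing in for a PairRanker call: prefer the longer answer.
def longer_is_better(a: str, b: str) -> bool:
    return len(a) > len(b)


print(rank_by_wins(["ok", "a longer answer", "medium one"], longer_is_better))
# [3, 1, 2]
```

With a loaded `blender`, `is_better` could wrap a single-element `blender.compare(...)` call, at the cost of n(n-1)/2 model calls per input. Because a learned pairwise model may be sensitive to the order of A and B, a more careful version might compare each pair in both orders.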
See the LLM-Blender GitHub [README.md](https://github.com/yuchenlin/LLM-Blender#rank-and-fusion)
and the Jupyter notebook [blender_usage.ipynb](https://github.com/yuchenlin/LLM-Blender/blob/main/blender_usage.ipynb)
for more detailed usage examples.