|
--- |
|
license: apache-2.0 |
|
datasets: |
|
- openai/summarize_from_feedback |
|
- openai/webgpt_comparisons |
|
- berkeley-nest/Nectar |
|
- Dahoas/instruct-synthetic-prompt-responses |
|
- Anthropic/hh-rlhf |
|
- lmsys/chatbot_arena_conversations |
|
- openbmb/UltraFeedback |
|
- argilla/ultrafeedback-binarized-preferences-cleaned |
|
metrics: |
|
- accuracy |
|
tags: |
|
- reward_model |
|
- reward-model |
|
- RLHF |
|
- evaluation |
|
- llm |
|
- instruction |
|
- reranking |
|
language: |
|
- en |
|
--- |
|
# Better Implementation of [*PairRM*](https://huggingface.co/llm-blender/PairRM) |
|
|
|
## Introduction |
|
|
|
This version of PairRM includes several fixes to the training process, which improve the model's performance by **15%**.
|
|
|
### Minor Fixes |
|
|
|
- Longer Context Length (2048 -> 3370) |
|
|
|
Thanks to DeBERTa's tokenizer, the original PairRM model already had a sufficient context length.

But the longer, the better :>
|
|
|
--- |
|
|
|
### Major Fixes |
|
|
|
- Change Prompt Format

The original prompt format wrapped each candidate in a generic tag like

```
<Response i + 1> {response}
```

which gives the model little natural-language context to judge by. So, I changed it to a format based on Vicuna 1.1.
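For reference, the Vicuna 1.1-style layout used here looks roughly like the following (this mirrors the Jinja template in the example code below; the placeholders are illustrative):

```
SYSTEM MESSAGE: {system prompt}

USER: {user message}

ASSISTANT: {assistant response}
```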
|
|
|
--- |
|
|
|
- Change Truncate Side

The original process used right-side truncation even on the input (source). This can cause serious problems when the input exceeds the model's context length, because the most recent turns of the conversation get cut off instead of the oldest ones.
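A minimal sketch of the difference, assuming the standard `transformers` tokenizer API (the `truncate_texts` helper in the example code below relies on exactly this switch):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-large")

long_history = "USER: hello\n\nASSISTANT: hi\n\n" * 500  # stand-in for a long conversation

# Right-side truncation keeps the oldest tokens and drops the latest turns.
tokenizer.truncation_side = "right"
kept_start = tokenizer.encode(long_history, max_length=2030, truncation=True)

# Left-side truncation drops the oldest turns and keeps the most recent ones,
# which is what you want for the conversation history (source) side.
tokenizer.truncation_side = "left"
kept_end = tokenizer.encode(long_history, max_length=2030, truncation=True)
```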
|
|
|
--- |
|
|
|
- Dataset Filter |
|
|
|
There was a decent amount of empty assistant responses in the original dataset, so I dropped them.
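A minimal sketch of that kind of filter, assuming the `datasets` library and the `Anthropic/hh-rlhf` schema (illustrative, not the exact preprocessing script):

```python
from datasets import load_dataset

def has_nonempty_final_response(text: str) -> bool:
    # hh-rlhf stores a whole dialogue as one string; the final assistant
    # turn is everything after the last "Assistant:" marker.
    return text.rsplit("Assistant:", 1)[-1].strip() != ""

ds = load_dataset("Anthropic/hh-rlhf", split="train")
ds = ds.filter(
    lambda ex: has_nonempty_final_response(ex["chosen"])
    and has_nonempty_final_response(ex["rejected"])
)
```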
|
|
|
--- |
|
|
|
## Example Code |
|
|
|
**The code below is modified from the [PairRM-hf repo](https://huggingface.co/llm-blender/PairRM-hf).**
|
|
|
```python |
|
import os |
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0" |
|
from llm_blender.pair_ranker.pairrm import DebertaV2PairRM |
|
from transformers import AutoTokenizer |
|
from typing import List |
|
pairrm = DebertaV2PairRM.from_pretrained("maywell/Better-PairRM", device_map="cuda:0").eval() |
|
tokenizer = AutoTokenizer.from_pretrained("maywell/Better-PairRM") |
|
source_prefix = "<|source|>" |
|
cand1_prefix = "<|candidate1|>" |
|
cand2_prefix = "<|candidate2|>" |
|
inputs = ["hello!", "I love you!"] |
|
candidates_A = ["hi!", "I hate you!"] |
|
candidates_B = ["f**k off!", "I love you, too!"] |
|
def tokenize_pair(sources: List[str], candidate1s: List[str], candidate2s: List[str], source_max_length=2030, candidate_max_length=670):
    ids = []
    assert len(sources) == len(candidate1s) == len(candidate2s)
    max_length = source_max_length + 2 * candidate_max_length
    for i in range(len(sources)):
        source_ids = tokenizer.encode(source_prefix + sources[i], max_length=source_max_length, truncation=True)
        # Give each candidate half of whatever budget the source did not use.
        candidate_max_length = (max_length - len(source_ids)) // 2
        candidate1_ids = tokenizer.encode(cand1_prefix + candidate1s[i], max_length=candidate_max_length, truncation=True)
        candidate2_ids = tokenizer.encode(cand2_prefix + candidate2s[i], max_length=candidate_max_length, truncation=True)
        ids.append(source_ids + candidate1_ids + candidate2_ids)
    encodings = tokenizer.pad({"input_ids": ids}, return_tensors="pt", padding="max_length", max_length=max_length)
    return encodings
|
|
|
encodings = tokenize_pair(inputs, candidates_A, candidates_B) |
|
encodings = {k:v.to(pairrm.device) for k,v in encodings.items()} |
|
outputs = pairrm(**encodings) |
|
logits = outputs.logits.tolist() |
|
comparison_results = outputs.logits > 0 |
|
print(logits) |
|
print(comparison_results) |
|
``` |
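A logit greater than zero means the model prefers the first candidate (`candidates_A[i]`) over the second for that pair, so `comparison_results` holds one boolean per input.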
|
|
|
You can also easily compare two conversations as follows:
|
```python |
|
import jinja2
from typing import List
from transformers import AutoTokenizer
|
|
|
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-large") |
|
|
|
def truncate_texts(text, max_length, truncate_side):
    # Truncate from the given side: "left" keeps the end of the text
    # (the most recent turns), "right" keeps the beginning.
    tokenizer.truncation_side = truncate_side
    tokens = tokenizer.encode(text, add_special_tokens=False, max_length=max_length, truncation=True)
    truncated_text = tokenizer.decode(tokens, skip_special_tokens=True)
    return truncated_text
|
|
|
MY_JINJA_TEMPLATE = """{% for message in messages -%} |
|
{% if message['role'] == 'user' -%} |
|
USER: {{ message['content']|trim -}} |
|
{% if not loop.last -%} |
|
|
|
|
|
{% endif %} |
|
{% elif message['role'] == 'assistant' -%} |
|
ASSISTANT: {{ message['content']|trim -}} |
|
{% if not loop.last -%} |
|
|
|
|
|
{% endif %} |
|
{% elif message['role'] == 'user_context' -%} |
|
USER: {{ message['content']|trim -}} |
|
{% if not loop.last -%} |
|
|
|
|
|
{% endif %} |
|
{% elif message['role'] == 'system' -%} |
|
SYSTEM MESSAGE: {{ message['content']|trim -}} |
|
{% if not loop.last -%} |
|
|
|
|
|
{% endif %} |
|
{% endif %} |
|
{% endfor -%} |
|
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' -%} |
|
ASSISTANT: {% endif -%}""" |
|
|
|
my_jinja2_env = jinja2.Environment() |
|
my_jinja2_template = my_jinja2_env.from_string(MY_JINJA_TEMPLATE) |
|
|
|
def tokenize_conv_pair(convAs: List[List[dict]], convBs: List[List[dict]]):
    # Check that the two conversation lists line up turn for turn.
    assert len(convAs) == len(convBs), "Number of conversations must be the same"
    for c_a, c_b in zip(convAs, convBs):
        assert len(c_a) == len(c_b), "Number of turns in each conversation must be the same"
        assert all([c_a[i]['content'] == c_b[i]['content'] for i in range(0, len(c_a), 2)]), "USER turns must be the same"

    # Render everything except the final assistant turn as the source,
    # truncating from the left so the most recent turns are kept.
    inputs = [
        truncate_texts(my_jinja2_template.render(messages=x[:-1], add_generation_prompt=True), 2030, "left") for x in convAs
    ]
    # The candidates are the final assistant turns, truncated from the right.
    cand1_texts = [
        truncate_texts(x[-1]['content'], 670, "right") for x in convAs
    ]
    cand2_texts = [
        truncate_texts(x[-1]['content'], 670, "right") for x in convBs
    ]
    encodings = tokenize_pair(inputs, cand1_texts, cand2_texts)
    return encodings
|
``` |
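A hypothetical usage sketch (it reuses `pairrm` and `tokenize_pair` from the first snippet; the conversation contents are made up):

```python
convA = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital of France is Paris."},
]
convB = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "France is a country in Europe."},
]

encodings = tokenize_conv_pair([convA], [convB])
encodings = {k: v.to(pairrm.device) for k, v in encodings.items()}
outputs = pairrm(**encodings)
print(outputs.logits > 0)  # True -> convA's final response is preferred
```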
|
|
|
## Statistics |
|
|
|
### Context length |
|
| PairRanker type | Source max length | Candidate max length | Total max length | |
|
|:-----------------:|:-----------------:|:--------------------:|:----------------:|
|
| [pair-ranker](https://huggingface.co/llm-blender/pair-ranker) | 128 | 128 | 384 | |
|
| [PairRM](https://huggingface.co/llm-blender/pair-reward-model/) | 1224 | 412 | 2048 | |
|
| [Better-PairRM](https://huggingface.co/maywell/Better-PairRM/) (This model) | 2030 | 670 | 3370 | |
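The total max length follows `tokenize_pair` above: `source_max_length + 2 * candidate_max_length`, i.e. 2030 + 2 × 670 = 3370 for this model.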
|
|
|
### Performance |
|
|
|
#### Reward-Bench by AllenAI |
|
|
|
| Metric | llm-blender/PairRM-hf | maywell/Better-PairRM | |
|
|----------------------------|------------------------|------------------------| |
|
|
| model_type | Custom Classifier | Custom Classifier | |
|
| alpacaeval-length | 0.758 | **0.863** | |
|
| alpacaeval-hard | 0.979 | **1.000** | |
|
| alpacaeval-easy | 0.970 | **0.990** | |
|
| donotanswer | 0.360 | **0.522** | |
|
| hep-cpp | 0.628 | **0.646** | |
|
| hep-go | 0.689 | **0.713** | |
|
| hep-java | 0.628 | **0.713** | |
|
| hep-js | 0.604 | **0.707** | |
|
| hep-python | 0.646 | **0.713** | |
|
| hep-rust | 0.652 | **0.726** | |
|
| llmbar-adver-GPTInst | **0.304** | 0.141 | |
|
| llmbar-adver-GPTOut | **0.596** | 0.447 | |
|
| llmbar-adver-manual | **0.500** | 0.261 | |
|
| llmbar-adver-neighbor | **0.433** | 0.276 | |
|
| llmbar-natural | **0.800** | 0.720 | |
|
| math-prm | **0.333** | 0.295 | |
|
| mt-bench-hard | 0.649 | **0.703** | |
|
| mt-bench-med | 0.900 | **1.000** | |
|
| mt-bench-easy | **0.964** | 0.929 | |
|
| refusals-dangerous | 0.080 | **0.730** | |
|
| refusals-offensive | 0.010 | **0.940** | |
|
| xstest-should-refuse | 0.370 | **0.968** | |
|
| xstest-should-respond | **0.952** | 0.876 | |
|
| average | 0.600 | **0.690** | |
|
|
|
> *Note - llmbar test scores look a bit off across all models on [Reward-Bench](https://huggingface.co/spaces/allenai/reward-bench).*
|
|
|
## Thanks to |
|
|
|
- [Sionic AI](https://sionic.ai/) for providing the A100 cluster. |
|
|
|
## Contact |
|
|
|
- [Discord Server Link](https://discord.gg/MrBt3PXdXc) |
|
|
|
## Original Paper |
|
``` |
|
@inproceedings{llm-blender-2023, |
|
title = "LLM-Blender: Ensembling Large Language Models with Pairwise Comparison and Generative Fusion", |
|
author = "Jiang, Dongfu and Ren, Xiang and Lin, Bill Yuchen", |
|
booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (ACL 2023)",
|
year = "2023" |
|
} |
|
``` |