license: apache-2.0
datasets:
  - openai/summarize_from_feedback
  - openai/webgpt_comparisons
  - berkeley-nest/Nectar
  - Dahoas/instruct-synthetic-prompt-responses
  - Anthropic/hh-rlhf
  - lmsys/chatbot_arena_conversations
  - openbmb/UltraFeedback
  - argilla/ultrafeedback-binarized-preferences-cleaned
metrics:
  - accuracy
tags:
  - reward_model
  - reward-model
  - RLHF
  - evaluation
  - llm
  - instruction
  - reranking
language:
  - en

Better Implementation of PairRM

Introduction

This version of PairRM includes several fixes to the training process, which improve the model's performance by about 15% (Reward-Bench average 0.600 -> 0.690, see Performance).

Minor Fixes

  • Longer Context Length (2048 -> 3370)

Thanks to DeBERTa's tokenizer, the original PairRM model already had enough context length.

But the longer the better :>


Major Fixes

  • Change Prompt Format

Why use something like

<Response i + 1> {response}

as the prompt format? I switched to a format based on Vicuna 1.1 instead (USER: / ASSISTANT: turns, as in the template under Example Code).


  • Change Truncate side

The original process used right-side truncation even on the Input. This can cause serious problems when the Input exceeds the model's context length, because the end of the Input (the most recent turns) is what gets cut off.
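To see why the truncation side matters, here is a minimal sketch on a plain list of tokens. `truncate_tokens` is a hypothetical helper written for illustration only; it is not part of llm-blender or transformers, and the toy token list stands in for a real tokenized conversation.

```python
from typing import List

def truncate_tokens(tokens: List[str], max_length: int, side: str) -> List[str]:
    """Keep at most max_length tokens, cutting from the given side (hypothetical helper)."""
    if side not in ("left", "right"):
        raise ValueError("side must be 'left' or 'right'")
    if max_length <= 0:
        return []
    if len(tokens) <= max_length:
        return list(tokens)
    # "right" truncation keeps the beginning; "left" truncation keeps the end.
    return tokens[:max_length] if side == "right" else tokens[-max_length:]

# Toy multi-turn input: the latest user turn sits at the end.
conversation = [
    "USER:", "old", "question",
    "ASSISTANT:", "old", "answer",
    "USER:", "latest", "question",
]

right_cut = truncate_tokens(conversation, 4, "right")
left_cut = truncate_tokens(conversation, 4, "left")
print(right_cut)  # the latest question is gone
print(left_cut)   # the latest question survives
```

With right-side truncation the ranker would judge the candidates against an old turn, which is why the example code in this README truncates the Input from the left.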


  • Dataset Filter

There was a decent number of empty assistant responses in the original dataset, so I dropped them.
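A filter of this kind could look roughly like the sketch below. The record layout (a list of `role`/`content` messages per conversation) and the helper names are assumptions made for illustration; the actual training data format and filtering code are not shown in this README.

```python
from typing import Dict, List

Conversation = List[Dict[str, str]]

def has_empty_assistant_turn(conversation: Conversation) -> bool:
    """True if any assistant message is empty or whitespace-only (hypothetical helper)."""
    return any(
        message["role"] == "assistant" and not message["content"].strip()
        for message in conversation
    )

def drop_empty_responses(dataset: List[Conversation]) -> List[Conversation]:
    """Keep only conversations where every assistant turn has some content."""
    return [conv for conv in dataset if not has_empty_assistant_turn(conv)]

# Toy data, made up for this example.
toy_dataset = [
    [{"role": "user", "content": "hello"}, {"role": "assistant", "content": "hi!"}],
    [{"role": "user", "content": "hello"}, {"role": "assistant", "content": "   "}],
    [{"role": "user", "content": "hello"}, {"role": "assistant", "content": ""}],
]

filtered = drop_empty_responses(toy_dataset)
print(len(toy_dataset), "->", len(filtered))
```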


Example Code

The code below is modified from the [PairRM-hf Repo](https://huggingface.co/llm-blender/PairRM-hf).

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from llm_blender.pair_ranker.pairrm import DebertaV2PairRM
from transformers import AutoTokenizer
from typing import List
pairrm = DebertaV2PairRM.from_pretrained("maywell/Better-PairRM", device_map="cuda:0").eval()
tokenizer = AutoTokenizer.from_pretrained("maywell/Better-PairRM")
source_prefix = "<|source|>"
cand1_prefix = "<|candidate1|>"
cand2_prefix = "<|candidate2|>"
inputs = ["hello!", "I love you!"]
candidates_A = ["hi!", "I hate you!"]
candidates_B = ["f**k off!", "I love you, too!"]
def tokenize_pair(sources:List[str], candidate1s:List[str], candidate2s:List[str], source_max_length=2030, candidate_max_length=670):
    ids = []
    assert len(sources) == len(candidate1s) == len(candidate2s)
    max_length = source_max_length + 2 * candidate_max_length
    for i in range(len(sources)):
        source_ids = tokenizer.encode(source_prefix + sources[i], max_length=source_max_length, truncation=True)
        # split whatever budget the source did not use evenly between the two candidates
        candidate_max_length = (max_length - len(source_ids)) // 2
        candidate1_ids = tokenizer.encode(cand1_prefix + candidate1s[i], max_length=candidate_max_length, truncation=True)
        candidate2_ids = tokenizer.encode(cand2_prefix + candidate2s[i], max_length=candidate_max_length, truncation=True)
        ids.append(source_ids + candidate1_ids + candidate2_ids)
    encodings = tokenizer.pad({"input_ids": ids}, return_tensors="pt", padding="max_length", max_length=max_length)
    return encodings

encodings = tokenize_pair(inputs, candidates_A, candidates_B)
encodings = {k:v.to(pairrm.device) for k,v in encodings.items()}
outputs = pairrm(**encodings)
logits = outputs.logits.tolist()
# a positive logit means candidates_A[i] is preferred over candidates_B[i]
comparison_results = outputs.logits > 0
print(logits)
print(comparison_results)
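Once the logits are available, picking the preferred response is a one-liner. The sketch below assumes, following the `logits > 0` comparison above, that a positive logit favors the first candidate. `preferred_candidates` is a hypothetical helper and the logit values are placeholders, not real model output.

```python
from typing import List

def preferred_candidates(
    logits: List[float], cands_a: List[str], cands_b: List[str]
) -> List[str]:
    """Return the preferred candidate per pair (hypothetical helper).

    Assumes logit > 0 means the candidate from cands_a wins.
    """
    if not (len(logits) == len(cands_a) == len(cands_b)):
        raise ValueError("logits and candidate lists must have the same length")
    return [a if logit > 0 else b for logit, a, b in zip(logits, cands_a, cands_b)]

# Placeholder logits for illustration only.
example_logits = [1.5, -0.8]
example_a = ["hi!", "I hate you!"]
example_b = ["f**k off!", "I love you, too!"]

winners = preferred_candidates(example_logits, example_a, example_b)
print(winners)
```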

You can also easily compare two conversations like the following:

import jinja2
from typing import List
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-large")

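def truncate_texts(text, max_length, truncate_side):
    # truncate_side is "left" (keep the end) or "right" (keep the beginning)
    tokenizer.truncation_side = truncate_side
    # truncation=True makes the cut to max_length explicit
    tokens = tokenizer.encode(text, add_special_tokens=False, max_length=max_length, truncation=True)
    truncated_text = tokenizer.decode(tokens, skip_special_tokens=True)
    return truncated_text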

MY_JINJA_TEMPLATE = """{% for message in messages -%}
{% if message['role'] == 'user' -%}
USER: {{ message['content']|trim -}}
{% if not loop.last -%}


{% endif %}
{% elif message['role'] == 'assistant' -%}
ASSISTANT: {{ message['content']|trim -}}
{% if not loop.last -%}


{% endif %}
{% elif message['role'] == 'user_context' -%}
USER: {{ message['content']|trim -}}
{% if not loop.last -%}


{% endif %}
{% elif message['role'] == 'system' -%}
SYSTEM MESSAGE: {{ message['content']|trim -}}
{% if not loop.last -%}


{% endif %}
{% endif %}
{% endfor -%}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' -%}
ASSISTANT: {% endif -%}"""

my_jinja2_env = jinja2.Environment()
my_jinja2_template = my_jinja2_env.from_string(MY_JINJA_TEMPLATE)

def tokenize_conv_pair(convAs: List[List[dict]], convBs: List[List[dict]]):
    # each conversation is a list of {'role': ..., 'content': ...} messages

    # check conversations correctness
    assert len(convAs) == len(convBs), "Number of conversations must be the same"
    for c_a, c_b in zip(convAs, convBs):
        assert len(c_a) == len(c_b), "Number of turns in each conversation must be the same"
        assert all([c_a[i]['content'] == c_b[i]['content'] for i in range(0, len(c_a), 2)]), "USER turns must be the same"
    
    inputs = [
        truncate_texts(my_jinja2_template.render(messages=x[:-1], add_generation_prompt=True), 2030, "left") for x in convAs
    ]
    cand1_texts = [
        truncate_texts(x[-1]['content'], 670, "right") for x in convAs
    ]
    cand2_texts = [
        truncate_texts(x[-1]['content'], 670, "right") for x in convBs
    ]
    encodings = tokenize_pair(inputs, cand1_texts, cand2_texts)
    return encodings
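The input shape that `tokenize_conv_pair` expects can be sketched as follows: two conversations with identical user turns that differ only in the assistant replies. `check_conv_pair` is a hypothetical standalone helper that mirrors the assertions above for a single pair, so the shape can be checked without loading the model or tokenizer.

```python
from typing import Dict, List

Conversation = List[Dict[str, str]]

def check_conv_pair(conv_a: Conversation, conv_b: Conversation) -> None:
    """Mirror the checks in tokenize_conv_pair for one pair (hypothetical helper)."""
    if len(conv_a) != len(conv_b):
        raise ValueError("Number of turns in each conversation must be the same")
    # Even positions are assumed to be USER turns, as in the original assertion.
    for i in range(0, len(conv_a), 2):
        if conv_a[i]["content"] != conv_b[i]["content"]:
            raise ValueError("USER turns must be the same")

conv_a = [
    {"role": "user", "content": "hello!"},
    {"role": "assistant", "content": "hi!"},
]
conv_b = [
    {"role": "user", "content": "hello!"},
    {"role": "assistant", "content": "go away!"},
]
check_conv_pair(conv_a, conv_b)  # same user turn, different replies: accepted

mismatched = [
    {"role": "user", "content": "goodbye!"},
    {"role": "assistant", "content": "bye!"},
]
try:
    check_conv_pair(conv_a, mismatched)
    pair_ok = True
except ValueError:
    pair_ok = False
print(pair_ok)
```

A valid pair would then be passed as `tokenize_conv_pair([conv_a], [conv_b])`.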

Statistics

Context length

| PairRanker type | Source max length | Candidate max length | Total max length |
|---|---|---|---|
| pair-ranker | 128 | 128 | 384 |
| PairRM | 1224 | 412 | 2048 |
| Better-PairRM (This model) | 2030 | 670 | 3370 |
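The totals in the table follow from one source plus two candidates, the same formula `tokenize_pair` uses for `max_length`. The short sketch below checks that arithmetic and shows how the per-candidate budget grows when the source is short. The function names are made up for this example.

```python
def total_max_length(source_max_length: int, candidate_max_length: int) -> int:
    """Total budget for one source and two candidates."""
    return source_max_length + 2 * candidate_max_length

def per_candidate_budget(total: int, source_len: int) -> int:
    """Budget left for each candidate after the source, as in tokenize_pair."""
    return (total - source_len) // 2

# (source max length, candidate max length) taken from the table above.
configs = {
    "pair-ranker": (128, 128),
    "PairRM": (1224, 412),
    "Better-PairRM": (2030, 670),
}
totals = {name: total_max_length(src, cand) for name, (src, cand) in configs.items()}
print(totals)

# A short source of 30 tokens leaves more room for each candidate.
short_source_budget = per_candidate_budget(totals["Better-PairRM"], 30)
print(short_source_budget)
```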

Performance

Reward-Bench by AllenAI

| Metric | llm-blender/PairRM-hf | maywell/Better-PairRM |
|---|---|---|
| model_type | Custom Classifier | Custom Classifier |
| alpacaeval-length | 0.758 | 0.863 |
| alpacaeval-hard | 0.979 | 1.000 |
| alpacaeval-easy | 0.970 | 0.990 |
| donotanswer | 0.360 | 0.522 |
| hep-cpp | 0.628 | 0.646 |
| hep-go | 0.689 | 0.713 |
| hep-java | 0.628 | 0.713 |
| hep-js | 0.604 | 0.707 |
| hep-python | 0.646 | 0.713 |
| hep-rust | 0.652 | 0.726 |
| llmbar-adver-GPTInst | 0.304 | 0.141 |
| llmbar-adver-GPTOut | 0.596 | 0.447 |
| llmbar-adver-manual | 0.500 | 0.261 |
| llmbar-adver-neighbor | 0.433 | 0.276 |
| llmbar-natural | 0.800 | 0.720 |
| math-prm | 0.333 | 0.295 |
| mt-bench-hard | 0.649 | 0.703 |
| mt-bench-med | 0.900 | 1.000 |
| mt-bench-easy | 0.964 | 0.929 |
| refusals-dangerous | 0.080 | 0.730 |
| refusals-offensive | 0.010 | 0.940 |
| xstest-should-refuse | 0.370 | 0.968 |
| xstest-should-respond | 0.952 | 0.876 |
| average | 0.600 | 0.690 |

Note: llmbar test scores are a bit odd across all models on Reward-Bench.
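For reference, the relative change between the two columns can be computed as below. This is a small sketch that assumes the 15% figure in the Introduction refers to the relative gain in the Reward-Bench average; the helper name is made up for this example.

```python
def relative_change(baseline: float, new: float) -> float:
    """Relative change of `new` with respect to `baseline`."""
    return (new - baseline) / baseline

# Values taken from the table above.
avg_gain = relative_change(0.600, 0.690)            # average row
llmbar_manual = relative_change(0.500, 0.261)       # llmbar-adver-manual row

print(f"average: {avg_gain:+.1%}")
print(f"llmbar-adver-manual: {llmbar_manual:+.1%}")
```

The average improves even though the llmbar adversarial subsets regress, so the per-subset rows are worth checking for your use case.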

Thanks to

Contact

Original Paper

@inproceedings{llm-blender-2023,
    title = "LLM-Blender: Ensembling Large Language Models with Pairwise Comparison and Generative Fusion",
    author = "Jiang, Dongfu and Ren, Xiang and Lin, Bill Yuchen",
    booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (ACL 2023)",
    year = "2023"
}