metadata
license: mit
datasets:
  - CarperAI/openai_summarize_tldr
language:
  - en
base_model:
  - EleutherAI/gpt-j-6b
  - CarperAI/openai_summarize_tldr_sft

ALT-Quark model

This is a Quark-based baseline developed during the research carried out for the ALT paper. The model is trained following the algorithm introduced in Quark, with one modification: we sample multiple generations per prompt and compute the reward quantiles locally (per prompt) instead of globally across all prompts. We found this to be crucial for training. Note that Quark was not introduced to tackle the alignment problem but to unlearn attributes in text completion tasks, such as toxicity, negative sentiment, or repetition.

The model is GPT-J (6B) fine-tuned on the TL;DR summarization dataset to be better aligned with human preferences on summaries, i.e., accounting for axes such as accuracy, coverage, and coherence.

Model description

The alignment process starts from an SFT checkpoint released by CarperAI and trained using their trlx library.

In a nutshell, the Quark method consists of sampling new generations and scoring them with a reward model in order to cluster them into reward quantiles. For every quantile in a pre-defined number of quantiles, a new reward quantile token is added to the tokenizer. Afterward, each generation is mapped to a reward quantile token, and that token is prepended to the input prompt for conditional language modeling training.
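The quantile assignment step, with quantiles computed per prompt as in our modification, can be sketched as follows. This is a minimal illustration and not the training code: the function name `assign_local_quantiles`, the equal-sized binning, and the convention that index 0 denotes the highest-reward quantile are assumptions made for the example.

```python
def assign_local_quantiles(rewards_per_prompt, num_quantiles=5):
    """Assign a quantile index to every sampled generation.

    rewards_per_prompt: list of lists, one inner list of reward-model
    scores per prompt (one score per sampled generation).
    Each score is ranked only against the other samples of the SAME prompt.
    Assumption: index 0 is the highest-reward quantile.
    """
    quantiles = []
    for rewards in rewards_per_prompt:
        n = len(rewards)
        # Indices of this prompt's samples, from highest to lowest reward.
        order = sorted(range(n), key=lambda i: rewards[i], reverse=True)
        row = [0] * n
        for rank, idx in enumerate(order):
            # Split the ranks into num_quantiles equal-sized bins.
            row[idx] = (rank * num_quantiles) // n
        quantiles.append(row)
    return quantiles


# Two prompts, five sampled generations each (made-up scores).
rewards = [
    [0.1, 2.0, 1.0, -0.5, 3.0],
    [10.0, 11.0, 12.0, 13.0, 14.0],
]
local_quantiles = assign_local_quantiles(rewards)
print(local_quantiles)
```

In this toy input every score of the second prompt is higher than every score of the first, yet both prompts still populate all five quantiles. With a global ranking, all samples of the second prompt would fall into the top quantiles instead.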

For a detailed description of Quark, please refer to the original paper.

The reward model used for scoring the generations can be found here. We used K = 5 quantile tokens, which were newly added to the tokenizer:

{'_QUANTILE_TOKEN_0_', '_QUANTILE_TOKEN_1_', '_QUANTILE_TOKEN_2_', '_QUANTILE_TOKEN_3_', '_QUANTILE_TOKEN_4_'}

Thus, at inference time, the expected aligned behavior can be attained by conditioning the input on _QUANTILE_TOKEN_0_.
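The conditioning scheme can be illustrated with plain string formatting. The sketch below is hedged: the helper names `build_training_text` and `build_inference_prompt` are hypothetical, and concatenating the generation directly after the prompt is an assumption about how training sequences are laid out. Only the token names and the use of `_QUANTILE_TOKEN_0_` at inference come from this card.

```python
NUM_QUANTILES = 5
QUANTILE_TOKENS = [f"_QUANTILE_TOKEN_{i}_" for i in range(NUM_QUANTILES)]

# When preparing a fresh tokenizer/model pair, new tokens like these are
# typically registered with `tokenizer.add_tokens(QUANTILE_TOKENS)` followed by
# `model.resize_token_embeddings(len(tokenizer))`. This step should not be
# needed for inference, assuming the released tokenizer already contains them.


def build_training_text(prompt, generation, quantile_idx):
    """Hypothetical helper: quantile token + prompt + sampled generation."""
    return f"{QUANTILE_TOKENS[quantile_idx]}{prompt}{generation}"


def build_inference_prompt(prompt, quantile_idx=0):
    """Hypothetical helper: condition the prompt on a quantile token
    (index 0 by default, as recommended for aligned behavior)."""
    return f"{QUANTILE_TOKENS[quantile_idx]}{prompt}"


train_text = build_training_text("POST: some post\nTL;DR:", " a summary", 3)
infer_prompt = build_inference_prompt("POST: some post\nTL;DR:")
print(train_text)
print(infer_prompt)
```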

Related Models: ALT-RM.

Intended uses & limitations

This model originates from a research project focused on alignment and is intended primarily for research purposes. Commercial use as an off-the-shelf model is discouraged, as it was not designed with such applications in mind. The model is tailored specifically for the summarization task, having been trained on the TL;DR dataset, though some out-of-distribution generalization may be possible for related datasets.

How to use

You should format the input by prepending the quantile token to the prompt as follows: _QUANTILE_TOKEN_0_{prompt}

from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

checkpoint_path = "sauc-abadal-lloret/gpt-j-6b-ALT-Quark-tldr"

tokenizer = AutoTokenizer.from_pretrained(checkpoint_path) 
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(checkpoint_path)
model.eval()

prompt = "_QUANTILE_TOKEN_0_SUBREDDIT: r/relationship_advice\nTITLE: I'm [18M] going to a party where an old middle \
school crush [17F] is also going.\nPOST: Story time! Back in the summer after 8th grade, I hung out with my group of \
friends everyday for the whole summer. There was this girl in the group and I really liked her. Like I had the biggest \
and dumbest crush on her. I was only 13 so I didn't know shit, but I was thinking she's perfect for me, I gotta marry \
her and all this dumb stuff. The puppy love was so strong I wanted to be a part of her life and I wanted her to be a \
part of my life. I never had the courage to ask her out, and we went to different high schools. Eventually we stopped \
talking but during high school I never really liked anyone else. Every other girl felt dull compared to her. I still \
get nostalgic thinking about her and what would've been different if I had the balls to ask her out. Anyway I'm going \
to a party this Friday and I heard she's coming. I honestly don't know what to do to so this goes great and eventually \
ends up in a relationship.\nTL;DR:"

inputs = tokenizer([prompt], padding=True, truncation=True, return_tensors="pt")
input_seq_len = inputs["input_ids"].shape[1]

generation_config = GenerationConfig(
    max_new_tokens = 64,  # cap on generated tokens; max_length is omitted since max_new_tokens takes precedence
    do_sample = False,    # greedy decoding
    num_beams = 1,
    num_return_sequences = 1,
    return_dict_in_generate = True,
    pad_token_id = tokenizer.pad_token_id,
)

outputs = model.generate(**inputs, generation_config=generation_config)
generated_input_ids = outputs["sequences"][:, input_seq_len:]
generated_text = tokenizer.batch_decode(
    generated_input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
generated_text
[" I'm going to a party where an old middle school crush is also going. I honestly don't know what to do to so this goes great and eventually ends up in a relationship."]

Training data

The model was trained on the TL;DR summarization dataset introduced in Stiennon et al.'s paper "Learning to summarize from human feedback". We used the dataset version from CarperAI, available on the Hugging Face Hub as CarperAI/openai_summarize_tldr.

Training procedure

The exact training procedure and hyperparameter configuration can be found in our paper.

Variables and metrics

As an evaluation metric, we compute GPT-4 win-rates over PPO on a random 1k subset of the test set. We use the prompt provided in the DPO paper and ask GPT-4 to compare the generations of ALT-RM and Quark against those of PPO. Furthermore, we report the following metrics computed on the whole test set: average reward model score, perplexity measured by the SFT reference policy as a proxy for fluency, and average length of the generations. In addition, we conduct an out-of-domain evaluation and compute GPT-4 win-rates on 100 articles from the test split of the CNN/DailyMail dataset.
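For readers unfamiliar with these metrics, the sketch below shows one common way to compute them from per-example records. It is an illustration under assumptions, not the evaluation code: the function names, the "A"/"B" judgment labels, and the per-sequence definition of perplexity (exponential of the negative mean token log-probability) are all hypothetical choices, and the numbers are made up.

```python
import math


def win_rate(judgments, model_label="A"):
    """Fraction of pairwise comparisons in which the judge preferred `model_label`."""
    return sum(j == model_label for j in judgments) / len(judgments)


def perplexity(token_logprobs):
    """exp(-mean log-probability) of one generation's tokens (natural log),
    with log-probabilities assumed to come from the SFT reference policy."""
    return math.exp(-sum(token_logprobs) / len(token_logprobs))


def average_length(token_counts):
    """Mean number of tokens per generation."""
    return sum(token_counts) / len(token_counts)


# Made-up judge outputs: "A" = evaluated model preferred, "B" = PPO preferred.
judgments = ["A", "B", "B", "A", "B"]
print(win_rate(judgments))

# Made-up log-probabilities: every token has probability 0.5.
print(perplexity([math.log(0.5)] * 4))

print(average_length([30, 50]))
```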

| Model | TL;DR (In-domain) | CNN/DailyMail (Out-of-domain) |
|---|---|---|
| Quark vs PPO | 0.36 | 0.40 |
| ALT-RM vs PPO | 0.50 | 0.48 |

Win-rates with GPT-4. TL;DR on 1000 randomly chosen test prompts and CNN/DailyMail on 100 randomly chosen test prompts.

| Model | RM | PPL | Avg. len | # Train |
|---|---|---|---|---|
| SFT | 2.89 | 1.96 | 31.25 | - |
| References | 2.89 | 11.84 | 32.60 | - |
| PPO | 3.38 | 2.29 | 67.52 | 116k |
| Quark | 3.52 | 1.82 | 49.42 | 19k |
| ALT-RM | 3.58 | 2.20 | 46.14 | 19k |

TL;DR metrics on the whole test set, including avg. reward model score, perplexity, avg. generation length, and number of training prompts.

BibTeX entry and citation info

@misc{lloret2024aligninglanguagemodelstextual,
      title={Towards Aligning Language Models with Textual Feedback}, 
      author={Saüc Abadal Lloret and Shehzaad Dhuliawala and Keerthiram Murugesan and Mrinmaya Sachan},
      year={2024},
      eprint={2407.16970},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2407.16970}, 
}