# pythia410m-sft-tldr / code / preproc_hh_rlhf.py
# (HuggingFace Hub page residue preserved below as comments so the file parses)
# mnoukhov's picture
# Training in progress, step 500
# 1904ee8 verified
from typing import Dict, List, Optional, Iterator, Callable, Union, Tuple
from collections import defaultdict
import tqdm
import datasets
def extract_anthropic_prompt(prompt_and_response: str) -> str:
    """Extract the Anthropic-style prompt from a prompt+response string.

    Returns everything up to and including the *last* '\\n\\nAssistant:'
    turn marker, i.e. the multi-turn prompt without the final response.

    Args:
        prompt_and_response: Full conversation text ending in an assistant reply.

    Returns:
        The prefix of ``prompt_and_response`` ending with '\\n\\nAssistant:'.

    Raises:
        ValueError: If the marker is absent. (Was an ``assert``, which is
            stripped under ``python -O``; input validation should raise.)
    """
    search_term = '\n\nAssistant:'
    # rfind: the prompt may contain earlier Assistant turns; keep them all.
    search_term_idx = prompt_and_response.rfind(search_term)
    if search_term_idx == -1:
        raise ValueError(f"Prompt and response does not contain '{search_term}'")
    return prompt_and_response[:search_term_idx + len(search_term)]
def get_hh(split: str, silent: bool = False, cache_dir: Optional[str] = None) -> Dict[str, List[str]]:
    """Load the Anthropic Helpful-Harmless (HH-RLHF) dataset and flatten it.

    Each raw example carries a 'chosen' and a 'rejected' conversation that
    share the same prompt; the shared prompt is split off and the two
    responses are kept separately.

    Prompts are structured as:
        \n\nHuman: <prompt>\n\nAssistant:
    Multiple turns are allowed, but the prompt always starts with
    \n\nHuman: and ends with \n\nAssistant:.

    Args:
        split: Dataset split to load, e.g. 'train' or 'test'.
        silent: If True, suppress the tqdm progress bar.
        cache_dir: Optional HuggingFace datasets cache directory.

    Returns:
        A dict of three parallel lists (index i describes the same example):
        {
            'prompt':   List[str],  # '\n\nHuman: ...\n\nAssistant:' prefixes
            'chosen':   List[str],  # preferred response for each prompt
            'rejected': List[str],  # dispreferred response for each prompt
        }
    """
    print(f'Loading HH dataset ({split} split) from Huggingface...')
    dataset = datasets.load_dataset('Anthropic/hh-rlhf', split=split, cache_dir=cache_dir)
    print('done')

    def split_prompt_and_responses(ex):
        # 'chosen' and 'rejected' share the same prompt prefix, so the
        # prompt extracted from 'chosen' is valid for both.
        prompt = extract_anthropic_prompt(ex['chosen'])
        chosen_response = ex['chosen'][len(prompt):]
        rejected_response = ex['rejected'][len(prompt):]
        return prompt, chosen_response, rejected_response

    data = defaultdict(list)
    for row in tqdm.tqdm(dataset, desc='Processing HH', disable=silent):
        prompt, chosen, rejected = split_prompt_and_responses(row)
        data['prompt'].append(prompt)
        data['chosen'].append(chosen)
        data['rejected'].append(rejected)
    return data
def main():
    """Build the flattened HH-RLHF DatasetDict and push it to the Hub."""
    data_train = datasets.Dataset.from_dict(get_hh('train'))
    data_test = datasets.Dataset.from_dict(get_hh('test'))
    dataset = datasets.DatasetDict({'train': data_train, 'test': data_test})
    # NOTE(review): removed a leftover `import pdb; pdb.set_trace()` breakpoint
    # that halted every run. Pushing requires Hub write credentials for the
    # 'sophiex' namespace.
    dataset.push_to_hub("sophiex/hh-rlhf")


if __name__ == '__main__':
    main()