from typing import Dict, List, Optional
from collections import defaultdict

import tqdm
import datasets


def extract_anthropic_prompt(prompt_and_response: str) -> str:
    """Extract the anthropic prompt from a prompt and response pair."""
    search_term = '\n\nAssistant:'
    search_term_idx = prompt_and_response.rfind(search_term)
    assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
    return prompt_and_response[:search_term_idx + len(search_term)]
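
# A minimal illustration of extract_anthropic_prompt on a hypothetical conversation
# string (not taken from the actual dataset):
#   extract_anthropic_prompt('\n\nHuman: How do I bake bread?\n\nAssistant: Start with flour.')
#   returns '\n\nHuman: How do I bake bread?\n\nAssistant:'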


def get_hh(split: str, silent: bool = False, cache_dir: Optional[str] = None) -> Dict[str, List[str]]:
    """Load the Anthropic Helpful-Harmless dataset from Huggingface and convert it to the necessary format.

    The dataset is converted to a column-oriented dictionary with the following structure:
        {
            'prompt': List[str],
            'chosen': List[str],
            'rejected': List[str],
        }
    where 'chosen' and 'rejected' hold the preferred and dispreferred completions for the prompt at the same index.

    Prompts are structured as follows:
        \n\nHuman: <prompt>\n\nAssistant:
    Multiple turns are allowed, but the prompt always starts with \n\nHuman: and ends with \n\nAssistant:.

    For this dataset, the SFT target is just the chosen response.
    """
    print(f'Loading HH dataset ({split} split) from Huggingface...')
    dataset = datasets.load_dataset('Anthropic/hh-rlhf', split=split, cache_dir=cache_dir)
    print('done')

    def split_prompt_and_responses(ex):
        # 'chosen' and 'rejected' share the same prompt prefix; strip it off to get the two responses.
        prompt = extract_anthropic_prompt(ex['chosen'])
        chosen_response = ex['chosen'][len(prompt):]
        rejected_response = ex['rejected'][len(prompt):]
        return prompt, chosen_response, rejected_response

    data = defaultdict(list)
    for row in tqdm.tqdm(dataset, desc='Processing HH', disable=silent):
        prompt, chosen, rejected = split_prompt_and_responses(row)
        data['prompt'].append(prompt)
        data['chosen'].append(chosen)
        data['rejected'].append(rejected)

    return data
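
# Illustrative shape of one processed row (hypothetical, shortened strings):
#   data['prompt'][i]   == '\n\nHuman: How do I bake bread?\n\nAssistant:'
#   data['chosen'][i]   == ' Start by mixing flour, water, and yeast...'
#   data['rejected'][i] == ' I am not sure.'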


def main():
    data_train = datasets.Dataset.from_dict(get_hh('train'))
    data_test = datasets.Dataset.from_dict(get_hh('test'))

    dataset = datasets.DatasetDict({'train': data_train, 'test': data_test})

    dataset.push_to_hub("sophiex/hh-rlhf")
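    # Once pushed, the dataset can be reloaded (assuming the Hub repo is public) with:
    #   datasets.load_dataset('sophiex/hh-rlhf')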


if __name__ == '__main__':
    main()