| |
| import torch |
| from torch.utils.data import Dataset, DataLoader |
| from torch.nn.utils.rnn import pad_sequence |
|
|
| |
| from transformers import AutoTokenizer |
|
|
| import math |
| import os |
| from tqdm import tqdm |
| import pickle |
|
|
# Directory containing the BabyLM `clean_train_10M` training files.
TRAIN_PATH_10M = '01-data/clean_train_10M'

# Corpus names; each one is read from f'{name}.train' inside the training directory.
DATASETS = [
    'bnc_spoken',
    'childes',
    'gutenberg',
    'open_subtitles',
    'simple_wiki',
    'switchboard',
]
|
|
class FullBabyLMDataset(Dataset):
    """Map-style dataset of fixed-length token chunks from the BabyLM corpora.

    Each corpus in ``DATASETS`` is read from ``TRAIN_PATH_10M/<name>.train``,
    tokenized as one string, and split into consecutive chunks of
    ``cfg["datapoint_length"]`` token ids; the last chunk of each corpus may be
    shorter. ``__getitem__`` wraps a chunk with the tokenizer's BOS/EOS ids.

    Args:
        cfg: Config mapping. Reads ``"tokenizer_dir"`` always and
            ``"datapoint_length"`` when tokenizing from scratch.
        pretokenized_data: Optional list of token-id lists (such as a previously
            saved ``self.data``). When given, no corpus files are read.
    """

    def __init__(self, cfg, pretokenized_data=None):
        self.tokenizer = AutoTokenizer.from_pretrained(
            cfg["tokenizer_dir"],
            trust_remote_code=True,
            local_files_only=True,
        )

        # Any of these may be None if the tokenizer does not define that token.
        self.model_bos = self.tokenizer.bos_token_id
        self.model_eos = self.tokenizer.eos_token_id
        self.model_pad = self.tokenizer.pad_token_id

        if pretokenized_data is not None:
            self.data = pretokenized_data
            return

        # NOTE(review): this always reads the 10M split; cfg["training_type"] is
        # not consulted here. Confirm that is intended for the "strict" setting.
        self.data = []
        for dataset in DATASETS:
            dataset_path = os.path.join(TRAIN_PATH_10M, f'{dataset}.train')
            tokenized_dataset = self._tokenize_file(dataset_path)
            self.data.extend(
                self._split_into_chunks(
                    tokenized_dataset, cfg["datapoint_length"], f"Chunking {dataset}"
                )
            )
            print(f"Chunked {dataset_path}")

    def _tokenize_file(self, dataset_path):
        """Read one corpus file and return its token ids as a flat sequence."""
        with open(dataset_path, 'r', encoding='utf-8') as f:
            # readlines() keeps each trailing '\n', so the joined text looks like
            # "line one\n line two\n ..."; left unchanged so tokenization stays
            # consistent with any previously cached chunks.
            all_text = ' '.join(f.readlines())
        print(f'Opened {dataset_path}')

        tokenized_dataset = self.tokenizer([all_text])['input_ids'][0]
        print(f'Tokenized {dataset_path}; {len(tokenized_dataset)} tokens total')
        return tokenized_dataset

    @staticmethod
    def _split_into_chunks(tokens, chunk_size, desc):
        """Return consecutive slices of ``tokens`` (as lists) of at most ``chunk_size`` ids."""
        num_chunks = math.ceil(len(tokens) / chunk_size)
        chunks = []
        for curr_chunk in tqdm(range(num_chunks), desc=desc):
            chunk_tokens = tokens[curr_chunk * chunk_size:(curr_chunk + 1) * chunk_size]
            # Normalize tensor slices to plain lists so __getitem__ can concatenate them.
            if isinstance(chunk_tokens, torch.Tensor):
                chunk_tokens = chunk_tokens.tolist()
            chunks.append(chunk_tokens)
        return chunks

    def __len__(self):
        """Return the number of chunks."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return chunk ``idx`` as a 1-D LongTensor with BOS prepended and EOS appended.

        A special token the tokenizer does not define (id ``None``) is skipped
        rather than inserted, since a ``None`` entry cannot become a tensor.
        """
        prefix = [] if self.model_bos is None else [self.model_bos]
        suffix = [] if self.model_eos is None else [self.model_eos]
        return torch.as_tensor(prefix + self.data[idx] + suffix, dtype=torch.long)
|
|
| |
def load_babylm_data(cfg):
    """Build a shuffled DataLoader over the BabyLM training chunks, using an on-disk cache.

    Token chunks are pickled under ``01-data/cached_train`` so tokenization runs
    only once; later calls load them and skip reading the corpus files.

    Args:
        cfg: Config mapping. Reads ``"training_type"`` and ``"batch_size"`` here
            and is passed through to ``FullBabyLMDataset``.

    Returns:
        A ``DataLoader`` whose batches are produced by ``get_collate_fn``.
    """
    num_words = "100M" if cfg["training_type"] == "strict" else "10M"
    # NOTE(review): FullBabyLMDataset always reads the 10M corpus, so the "100M"
    # cache name may be mislabeled. The cache key also ignores the tokenizer and
    # cfg["datapoint_length"]; delete the cache file after changing either.
    cache_dir = '01-data/cached_train'
    os.makedirs(cache_dir, exist_ok=True)
    filename = os.path.join(cache_dir, f'train_gpt2_{num_words}.pkl')

    if os.path.exists(filename):
        # pickle.load can execute arbitrary code; only load caches this script wrote.
        with open(filename, 'rb') as f:
            token_chunks = pickle.load(f)
        full_babylm_dset = FullBabyLMDataset(cfg, pretokenized_data=token_chunks)
    else:
        full_babylm_dset = FullBabyLMDataset(cfg)
        # Dump to a temporary file and rename it into place, so an interrupted
        # dump never leaves a truncated cache that the exists() check above
        # would later try (and fail) to load.
        tmp_filename = f'{filename}.tmp'
        with open(tmp_filename, 'wb') as f:
            pickle.dump(full_babylm_dset.data, f)
        os.replace(tmp_filename, filename)

    collate_fn = get_collate_fn(full_babylm_dset.model_eos, full_babylm_dset.model_pad)
    dataloader = DataLoader(
        full_babylm_dset,
        batch_size=cfg["batch_size"],
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=False,
    )
    return dataloader
|
|
def get_collate_fn(model_eos, model_pad):
    """Return a collate function that pads a batch and builds next-token targets.

    Args:
        model_eos: EOS token id; used as the padding value when ``model_pad`` is None.
        model_pad: Padding token id, or None if the tokenizer defines none.

    Returns:
        ``collate_fn(batch)``, which takes a list of 1-D token-id tensors and
        returns ``(input_tokens, target_tokens, target_mask)``, each of shape
        ``(batch, max_len - 1)``. ``target_mask[b, i]`` is True exactly when
        ``target_tokens[b, i]`` is a real token of sequence ``b`` (not padding).
    """
    # Padded positions are always masked out, so any valid id works as filler.
    if model_pad is not None:
        pad_value = model_pad
    elif model_eos is not None:
        pad_value = model_eos
    else:
        pad_value = 0

    def collate_fn(batch):
        tokens = pad_sequence(batch, padding_value=pad_value, batch_first=True)
        input_tokens = tokens[:, :-1]
        target_tokens = tokens[:, 1:]
        # Derive the mask from sequence lengths, not by comparing ids to the pad id:
        # with a distinct pad id, `input != pad` kept the EOS -> pad position as a
        # target, and when pad shares an id with BOS/EOS an id comparison cannot
        # tell real tokens from padding. A sequence of length L has L - 1 targets.
        num_targets = torch.as_tensor([len(seq) - 1 for seq in batch], device=tokens.device)
        positions = torch.arange(input_tokens.size(1), device=tokens.device)
        target_mask = positions.unsqueeze(0) < num_targets.unsqueeze(1)
        return input_tokens, target_tokens, target_mask

    return collate_fn
|
|