import torch
from torch.utils.data import IterableDataset
from transformers import PreTrainedTokenizerBase

from pile import ThePile


class ThePileTokenized(IterableDataset):
    def __init__(
        self,
        base_dataset: ThePile,
        tokenizer: PreTrainedTokenizerBase,
        max_length: int = 1024,
        repeat_factor: float = 1.0,
    ):
        self.pile = base_dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.repeat_factor = repeat_factor

    def __iter__(self):
        ds = iter(self.pile)
        buffer = []
        while True:
            # Tokenize the next document and append it to the buffer,
            # separating documents with an EOS token.
            tokens = self.tokenizer.encode(next(ds)["text"])
            buffer += [self.tokenizer.eos_token_id] + tokens
            # Emit fixed-length windows. Advancing the buffer by
            # max_length / repeat_factor makes consecutive windows overlap,
            # so each token is seen ~repeat_factor times on average.
            while len(buffer) > self.max_length:
                yield torch.tensor(buffer[: self.max_length])
                buffer = buffer[int(self.max_length / self.repeat_factor) :]


if __name__ == "__main__":
    from tqdm import tqdm
    from torch.utils.data import DataLoader
    from transformers import GPT2Tokenizer

    dataset = ThePileTokenized(
        ThePile("train"),
        GPT2Tokenizer.from_pretrained("gpt2"),
        max_length=2048,
        repeat_factor=4 / 3,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=1,
    )
    # Benchmark raw iteration speed; consume batches without doing any work.
    for batch in tqdm(dataloader, smoothing=0.01):
        x = 0  # ~6 iters/s for 1 worker