Spaces:
Sleeping
Sleeping
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from tqdm import tqdm | |
from itertools import chain | |
from torch.utils.data import Dataset | |
class ConcatDataset(Dataset):
    """Pack the samples of a tokenized dataset into fixed-length chunks.

    Each sample yielded by ``dataset`` is indexed with the keys
    ``"input_ids"``, ``"attention_mask"`` and ``"labels"``; the values are
    concatenated onto running buffers with ``+`` (so they are expected to be
    lists -- TODO confirm against the upstream tokenization step).  Whenever
    the buffers hold at least ``chunk_size`` items, the first ``chunk_size``
    items of every buffer are emitted as one sample.

    Items left over after the last input sample that do not fill a whole
    chunk are discarded.

    Args:
        dataset: Iterable of mappings providing the three keys above.
        chunk_size: Number of items per emitted sample; must be positive.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """

    def __init__(self, dataset, chunk_size=4096):
        # A non-positive chunk size would make the packing loop below spin
        # forever (slicing by 0 never shrinks the buffer), so fail fast.
        if chunk_size <= 0:
            raise ValueError(
                f"chunk_size must be a positive integer, got {chunk_size!r}"
            )

        self.dataset = dataset
        self.chunk_size = chunk_size
        self.samples = []

        buffer = {
            "input_ids": [],
            "attention_mask": [],
            "labels": [],
        }

        for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
            buffer = {k: v + sample[k] for k, v in buffer.items()}

            # The fill level is read from the first buffer ("input_ids");
            # presumably all three fields have equal length per sample.
            # ``>=`` (not ``>``) so that a buffer holding exactly one full
            # chunk is emitted instead of being silently dropped when it is
            # produced by the final sample.
            while len(next(iter(buffer.values()))) >= self.chunk_size:
                self.samples.append({k: v[:self.chunk_size] for k, v in buffer.items()})
                buffer = {k: v[self.chunk_size:] for k, v in buffer.items()}

    def __getitem__(self, idx):
        """Return the packed sample stored at position ``idx``."""
        return self.samples[idx]

    def __len__(self):
        """Return the number of packed samples."""
        return len(self.samples)