# mini-omni-s2s / slam_llm/data/concatenator.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from tqdm import tqdm
from itertools import chain
from torch.utils.data import Dataset
class ConcatDataset(Dataset):
    """Pack variable-length tokenized samples into fixed-size chunks.

    Concatenates the "input_ids", "attention_mask", and "labels" lists of
    consecutive samples from ``dataset`` and slices the stream into chunks of
    exactly ``chunk_size`` tokens, materialized eagerly in ``self.samples``.
    Any trailing remainder shorter than ``chunk_size`` is discarded.

    Args:
        dataset: Iterable of dicts whose values are lists of token-level
            entries; every dict is assumed to hold the three keys above with
            equal-length lists (TODO confirm against the producing pipeline).
        chunk_size: Number of tokens per packed sample (default 4096).
    """

    def __init__(self, dataset, chunk_size=4096):
        self.dataset = dataset
        self.chunk_size = chunk_size
        self.samples = []

        buffer = {
            "input_ids": [],
            "attention_mask": [],
            "labels": [],
        }
        for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
            # Extend the buffers in place: rebuilding the dict with
            # `v + sample[k]` would re-copy the whole accumulated buffer for
            # every sample (accidental O(n^2)).
            for k in buffer:
                buffer[k].extend(sample[k])
            # `>=` (not `>`) so a buffer landing on exactly chunk_size tokens
            # is emitted instead of being silently dropped at end of data.
            while len(next(iter(buffer.values()))) >= self.chunk_size:
                self.samples.append({k: v[:self.chunk_size] for k, v in buffer.items()})
                buffer = {k: v[self.chunk_size:] for k, v in buffer.items()}

    def __getitem__(self, idx):
        """Return the idx-th packed chunk (a dict of three length-chunk_size lists)."""
        return self.samples[idx]

    def __len__(self):
        """Return the number of complete packed chunks."""
        return len(self.samples)