SimpleAES / data /dataset.py
SFM2001's picture
upload files
4f591e5
raw
history blame contribute delete
833 Bytes
import torch
class EssayDataset(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes one table row per item on demand.

    Each item is read positionally (``.iloc``). Its ``'train_input'`` value
    is passed through ``tokenizer``, and its ``'labels'`` value is converted
    into a ``torch.float`` tensor.

    Args:
        dataframe: Table holding at least the ``'train_input'`` and
            ``'labels'`` columns. It is accessed via ``len()`` and ``.iloc``,
            so it is presumably a pandas DataFrame (TODO: confirm with
            callers).
        tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
            padding='max_length', truncation=True, return_tensors='pt')``.
            Its result must support ``['input_ids']`` and
            ``['attention_mask']`` lookups yielding tensors.
        max_length: Value forwarded as the tokenizer's ``max_length``.
    """

    def __init__(self, dataframe, tokenizer, max_length):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows in the underlying table."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize the row at position ``idx`` and return its tensors.

        Returns:
            A dict with ``'input_ids'`` and ``'attention_mask'`` flattened
            to 1-D, plus ``'labels'`` as a ``torch.float`` tensor.
        """
        row = self.data.iloc[idx]
        # Both fields are read before tokenizing, so a missing column fails
        # before the tokenizer is ever called.
        text, target = row['train_input'], row['labels']

        encoded = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        # For a single string the tokenizer presumably returns a leading
        # batch dimension of 1. flatten() collapses each tensor to 1-D so
        # that a DataLoader can stack items itself.
        sample = {
            field: encoded[field].flatten()
            for field in ('input_ids', 'attention_mask')
        }
        sample['labels'] = torch.tensor(target, dtype=torch.float)
        return sample