|
import torch |
|
class EssayDataset(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes essay texts from a DataFrame.

    Each item is a dict ready for a HuggingFace-style transformer model:
    fixed-length ``input_ids`` and ``attention_mask`` tensors plus a float
    ``labels`` tensor.
    """

    def __init__(self, dataframe, tokenizer, max_length,
                 text_column='train_input', label_column='labels'):
        """
        Args:
            dataframe: pandas DataFrame holding one example per row.
            tokenizer: callable with a HuggingFace tokenizer interface
                (accepts ``max_length``, ``padding``, ``truncation``,
                ``return_tensors`` keyword arguments).
            max_length: fixed sequence length to pad/truncate to.
            text_column: name of the column holding the raw text
                (default ``'train_input'`` preserves previous behavior).
            label_column: name of the column holding the label values
                (default ``'labels'`` preserves previous behavior).
        """
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.text_column = text_column
        self.label_column = label_column

    def __len__(self):
        """Return the number of examples (rows) in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize and return the example at position ``idx``.

        Returns:
            dict with:
                ``input_ids``: LongTensor of shape ``(max_length,)``
                ``attention_mask``: LongTensor of shape ``(max_length,)``
                ``labels``: FloatTensor built from the label column value.
        """
        # Fetch the row once instead of two separate .iloc lookups.
        row = self.data.iloc[idx]
        text = row[self.text_column]
        labels = row[self.label_column]

        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # The tokenizer returns a batch dimension of 1; flatten drops it
        # so the DataLoader's collate can stack items cleanly.
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(labels, dtype=torch.float)
        }
|
|
|
|