import pandas as pd
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer


class MyDataset(Dataset):

    def __init__(self, data_file, tokenizer):
        # Expects a CSV whose first three columns are: text, agent label, action label.
        self.data = pd.read_csv(data_file)
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data.iloc[idx, 0]
        agents = self.data.iloc[idx, 1]
        actions = self.data.iloc[idx, 2]

        # Tokenize the text, padding/truncating to a fixed length of 512 tokens.
        encoding = self.tokenizer.encode_plus(
            text,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        # encode_plus returns tensors of shape (1, 512); flatten to (512,)
        # so the DataLoader can stack items into a batch.
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels_agents': torch.tensor(agents),
            'labels_actions': torch.tensor(actions)
        }
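

# A minimal usage sketch, not part of the original snippet: the checkpoint
# name 'bert-base-uncased', the file name 'data.csv', and the DataLoader
# settings below are illustrative assumptions.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')  # assumed checkpoint
    dataset = MyDataset('data.csv', tokenizer)                      # assumed CSV path
    loader = DataLoader(dataset, batch_size=16, shuffle=True)

    batch = next(iter(loader))
    print(batch['input_ids'].shape)       # -> torch.Size([16, 512])
    print(batch['labels_agents'].shape)   # -> torch.Size([16])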