# HateGuard / dataset.py
# Hugging Face file-viewer residue: uploaded by DataRaptor ("Upload 7 files"),
# commit 38f2246, 1.38 kB, no virus detected.
from torch import nn
from transformers import AutoConfig, AutoModel, AutoTokenizer
import torch
from torch.utils.data import Dataset
from utils import read_yaml
class BanglaHSDataset(Dataset):
    """Tokenize raw text on demand instead of serving stored samples.

    Items are looked up by the text itself (``ds['some sentence']``) rather
    than by an integer position, and ``len()`` is always 0 because nothing is
    stored in the dataset.
    """

    def __init__(self, tokenizer, max_length):
        """Remember the tokenizer callable and the fixed sequence length.

        Args:
            tokenizer: Callable invoked with the text plus ``max_length``,
                ``padding``, ``truncation`` and ``return_offsets_mapping``
                keyword arguments; it must return a mutable mapping.
            max_length: Length every sequence is padded / truncated to.
        """
        self.max_length = max_length
        self.tokenizer = tokenizer

    def __len__(self):
        """Return 0: the dataset holds no stored samples."""
        return 0

    def __getitem__(self, text):
        """Encode ``text`` and return ``(inputs, label)``.

        Args:
            text: The raw input passed straight to the tokenizer.

        Returns:
            A 2-tuple of the tokenizer's mapping, with every value replaced
            by a ``torch.long`` tensor carrying a leading dimension of size
            1, and a constant scalar ``torch.float`` tensor holding 0.
        """
        encoded = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_offsets_mapping=False,
        )

        # Overwrite the values in place so the mapping object produced by the
        # tokenizer is the one handed back to the caller.
        for field in list(encoded.keys()):
            as_tensor = torch.tensor(encoded[field], dtype=torch.long)
            encoded[field] = as_tensor.unsqueeze(0)

        constant_label = torch.tensor(0, dtype=torch.float)
        return encoded, constant_label
def get_class(index):
    """Return the category name stored at position ``index``.

    Args:
        index: Zero-based position in the category list (0 to 3).

    Returns:
        The category name as a string.

    Raises:
        IndexError: If ``index`` is negative or past the end of the list.
    """
    ind2cat = (
        'Geopolitical',
        'Personal',
        'Political',
        'Religious',
    )
    # Plain sequence indexing would let a negative index wrap around to the
    # end of the list and return a wrong-but-plausible category; reject it
    # with the same exception an index that is too large already raises.
    if index < 0:
        raise IndexError(f'class index out of range: {index}')
    return ind2cat[index]
if __name__ == '__main__':
    # Script entry point: only the YAML configuration is loaded. The
    # model / dataset smoke test that follows is kept disabled.
    cfg = read_yaml('./baseline.yaml')

    # NOTE(review): before re-enabling the snippet below, note that it uses
    # BanglaHS_Model, which this file does not import, and that it builds
    # BanglaHSDataset(cfg.Dataset, model) although the class constructor
    # takes (tokenizer, max_length).
    #
    # cfg.Model.target_size = 6
    # model = BanglaHS_Model(cfg.Model)
    # # model.load_state_dict(
    # #     torch.load('./model_fold-0_best.pt', map_location=torch.device('cpu')))
    # model.eval()
    # ds = BanglaHSDataset(cfg.Dataset, model)
    # x = ds['Hello hi'][0]
    # with torch.no_grad():
    #     y = model(x)
    # print('y:', y)