import logging
from typing import List, Mapping, Optional

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
from catalyst.utils import set_global_seed
class DatasetWrapper(Dataset):
    """Wraps raw texts (and optional labels) into a tokenized torch Dataset."""

    def __init__(
        self,
        texts: List[str],
        labels: Optional[List[str]] = None,
        label_dict: Optional[Mapping[str, int]] = None,
        max_seq_length: int = 512,
        model_name: str = "distilbert-base-uncased",
    ):
        self.texts = texts
        self.labels = labels
        self.label_dict = label_dict
        self.max_seq_length = max_seq_length

        # Build a label -> index mapping from the sorted label set
        # when none is supplied explicitly.
        if self.label_dict is None and labels is not None:
            self.label_dict = dict(
                zip(sorted(set(labels)), range(len(set(labels))))
            )

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Silence the tokenizer's sequence-length warnings.
        logging.getLogger("transformers.tokenization_utils").setLevel(logging.FATAL)

        self.sep_vid = self.tokenizer.vocab["[SEP]"]
        self.cls_vid = self.tokenizer.vocab["[CLS]"]
        self.pad_vid = self.tokenizer.vocab["[PAD]"]

    def __len__(self) -> int:
        return len(self.texts)

    def __getitem__(self, index) -> Mapping[str, torch.Tensor]:
        x = self.texts[index]

        # Tokenize, pad/truncate to max_seq_length, and return torch tensors.
        output_dict = self.tokenizer.encode_plus(
            x,
            add_special_tokens=True,
            padding="max_length",
            max_length=self.max_seq_length,
            return_tensors="pt",
            truncation=True,
            return_attention_mask=True,
        )

        # Catalyst expects the token ids under the "features" key.
        output_dict["features"] = output_dict["input_ids"].squeeze(0)
        del output_dict["input_ids"]

        # Encode the label (if any) as a scalar tensor under "targets";
        # unknown labels map to -1.
        if self.labels is not None:
            y = self.labels[index]
            y_encoded = torch.Tensor([self.label_dict.get(y, -1)]).long().squeeze(0)
            output_dict["targets"] = y_encoded

        return output_dict
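
# Hypothetical usage sketch (the texts and labels below are made-up placeholders,
# not part of the original app): constructing a labeled dataset directly and
# inspecting one item produced by __getitem__ above.
#
#     train_dataset = DatasetWrapper(
#         texts=["great movie", "terrible plot"],
#         labels=["positive", "negative"],
#         max_seq_length=128,
#     )
#     sample = train_dataset[0]
#     sample["features"].shape  # torch.Size([128]), token ids padded to max_seq_length
#     sample["targets"]         # tensor(1), the index assigned to "positive" in label_dict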
def read_data(text: Optional[List[str]] = None) -> Mapping[str, DataLoader]:
    """Builds a single "test" DataLoader over the given texts."""
    dataset = DatasetWrapper(
        texts=text,
        max_seq_length=128,
        model_name="distilbert-base-uncased",
    )

    set_global_seed(17)

    loader = {
        "test": DataLoader(
            dataset=dataset,
            batch_size=8,
            shuffle=False,
        )
    }
    return loader
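
# Minimal, hypothetical smoke test (not part of the original app; the sentences
# below are placeholders). read_data expects a list of raw strings and returns a
# dict whose "test" loader yields batches with "features" and "attention_mask".
if __name__ == "__main__":
    loaders = read_data(text=["an example sentence", "another example"])
    for batch in loaders["test"]:
        # features: (batch_size, 128); attention_mask keeps the extra dim
        # returned by encode_plus, i.e. (batch_size, 1, 128).
        print(batch["features"].shape, batch["attention_mask"].shape)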