import logging
from typing import List, Mapping

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer

from catalyst.utils import set_global_seed


class DatasetWrapper(Dataset):
    """Wraps raw texts (and optional labels) into a tokenized PyTorch dataset."""

    def __init__(
        self,
        texts: List[str],
        labels: List[str] = None,
        label_dict: Mapping[str, int] = None,
        max_seq_length: int = 512,
        model_name: str = "distilbert-base-uncased",
    ):
        self.texts = texts
        self.labels = labels
        self.label_dict = label_dict
        self.max_seq_length = max_seq_length

        # Build a label -> index mapping if one was not supplied.
        if self.label_dict is None and labels is not None:
            self.label_dict = dict(zip(sorted(set(labels)), range(len(set(labels)))))

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Silence the tokenizer's sequence-length warnings.
        logging.getLogger("transformers.tokenization_utils").setLevel(logging.FATAL)

        # Ids of the special tokens, kept for downstream use.
        self.sep_vid = self.tokenizer.vocab["[SEP]"]
        self.cls_vid = self.tokenizer.vocab["[CLS]"]
        self.pad_vid = self.tokenizer.vocab["[PAD]"]

    def __len__(self) -> int:
        return len(self.texts)

    def __getitem__(self, index: int) -> Mapping[str, torch.Tensor]:
        x = self.texts[index]
        output_dict = self.tokenizer.encode_plus(
            x,
            add_special_tokens=True,
            padding="max_length",
            max_length=self.max_seq_length,
            return_tensors="pt",
            truncation=True,
            return_attention_mask=True,
        )

        # encode_plus with return_tensors="pt" adds a batch dimension of 1;
        # squeeze it so the DataLoader can collate samples into [batch, seq_len].
        output_dict["features"] = output_dict["input_ids"].squeeze(0)
        output_dict["attention_mask"] = output_dict["attention_mask"].squeeze(0)
        del output_dict["input_ids"]

        if self.labels is not None:
            y = self.labels[index]
            # Labels missing from label_dict are mapped to -1.
            y_encoded = torch.Tensor([self.label_dict.get(y, -1)]).long().squeeze(0)
            output_dict["targets"] = y_encoded

        return output_dict


def read_data(text: List[str] = None) -> Mapping[str, DataLoader]:
    """Builds a "test" DataLoader over the given texts."""
    dataset = DatasetWrapper(
        texts=text,
        max_seq_length=128,
        model_name="distilbert-base-uncased",
    )

    set_global_seed(17)

    loader = {
        "test": DataLoader(
            dataset=dataset,
            batch_size=8,
            shuffle=False,
        )
    }

    return loader
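
# Minimal usage sketch: build the "test" loader via read_data and inspect one
# collated batch. The sample sentences below are arbitrary placeholders chosen
# only for illustration.
if __name__ == "__main__":
    sample_texts = [
        "This movie was surprisingly good.",
        "The plot made no sense at all.",
    ]
    loaders = read_data(text=sample_texts)
    batch = next(iter(loaders["test"]))
    # "features" holds the padded token ids; "attention_mask" marks real tokens.
    print(batch["features"].shape)        # expected: torch.Size([2, 128])
    print(batch["attention_mask"].shape)  # expected: torch.Size([2, 128])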