Spaces:
Build error
Build error
import logging | |
from typing import List, Mapping, Tuple | |
import pandas as pd | |
import torch | |
from torch.utils.data import DataLoader, Dataset | |
from transformers import AutoTokenizer | |
from catalyst.utils import set_global_seed | |
class DatasetWrapper(Dataset):
    """Torch dataset that tokenizes raw texts on access.

    Each item is the tokenizer output for one text, with the token ids
    exposed under ``"features"`` instead of ``"input_ids"`` and, when labels
    were supplied, the encoded label under ``"targets"``.
    """

    def __init__(
        self,
        texts: List[str],
        labels: List[str] = None,
        label_dict: Mapping[str, int] = None,
        max_seq_length: int = 512,
        model_name: str = "distilbert-base-uncased",
    ):
        """
        Args:
            texts: raw input texts, one per sample.
            labels: optional labels, looked up by the same index as ``texts``.
            label_dict: optional mapping from label to class index. When it
                is omitted and ``labels`` is given, one is built from the
                sorted unique labels.
            max_seq_length: length every sequence is padded/truncated to.
            model_name: identifier passed to ``AutoTokenizer.from_pretrained``.
        """
        self.texts = texts
        self.labels = labels
        self.label_dict = label_dict
        self.max_seq_length = max_seq_length

        if self.label_dict is None and labels is not None:
            # Sorting makes the label -> index assignment deterministic.
            self.label_dict = {
                label: idx for idx, label in enumerate(sorted(set(labels)))
            }

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Only FATAL records from this transformers logger are let through.
        logging.getLogger("transformers.tokenization_utils").setLevel(logging.FATAL)

        # Take the special-token ids from the tokenizer itself rather than
        # looking up the BERT-specific strings "[SEP]"/"[CLS]"/"[PAD]", which
        # raised KeyError for any `model_name` with a different vocabulary.
        self.sep_vid = self.tokenizer.sep_token_id
        self.cls_vid = self.tokenizer.cls_token_id
        self.pad_vid = self.tokenizer.pad_token_id

    def __len__(self) -> int:
        """Return the number of texts in the dataset."""
        return len(self.texts)

    def __getitem__(self, index) -> Mapping[str, torch.Tensor]:
        """Tokenize the text at ``index`` and attach its encoded label.

        Returns:
            The tokenizer output with ``"input_ids"`` replaced by
            ``"features"`` (leading dimension squeezed away). Every other
            tokenizer field is returned exactly as the tokenizer produced it.
            When labels are present, ``"targets"`` holds the label index as a
            scalar long tensor, or ``-1`` for a label missing from
            ``label_dict``.
        """
        text = self.texts[index]
        output_dict = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            padding="max_length",
            max_length=self.max_seq_length,
            return_tensors="pt",
            truncation=True,
            return_attention_mask=True,
        )

        # Rename the ids to "features" and drop the leading dimension that
        # return_tensors="pt" adds for a single text.
        output_dict["features"] = output_dict["input_ids"].squeeze(0)
        del output_dict["input_ids"]

        if self.labels is not None:
            label = self.labels[index]
            # Build the integer tensor directly: going through torch.Tensor
            # (float32) first would round ids larger than 2**24.
            output_dict["targets"] = torch.tensor(
                self.label_dict.get(label, -1), dtype=torch.long
            )

        return output_dict
def read_data(
    text=None,
    *,
    max_seq_length: int = 128,
    model_name: str = "distilbert-base-uncased",
    batch_size: int = 8,
    seed: int = 17,
) -> Mapping[str, DataLoader]:
    """Build the unshuffled ``"test"`` data loader for the given text(s).

    Args:
        text: a sequence of texts, a single string, or ``None``. A single
            string is treated as one sample; ``None`` as no samples.
        max_seq_length: padded/truncated sequence length for the dataset.
        model_name: tokenizer identifier forwarded to ``DatasetWrapper``.
        batch_size: number of samples per batch.
        seed: value passed to ``set_global_seed``.

    Returns:
        A dict with the single key ``"test"`` mapped to the ``DataLoader``.
    """
    if text is None:
        # The default used to produce a loader that failed with TypeError on
        # first iteration; an empty dataset is the well-defined equivalent.
        texts = []
    elif isinstance(text, str):
        # A bare string would otherwise be indexed character by character,
        # yielding one sample per character.
        texts = [text]
    else:
        texts = text

    dataset = DatasetWrapper(
        texts=texts,
        max_seq_length=max_seq_length,
        model_name=model_name,
    )

    set_global_seed(seed)

    return {
        "test": DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
        )
    }