# Hugging Face Spaces file view -- Space status: Build error
# File size: 2,263 bytes, revision 4c8fe65
import logging
from typing import List, Mapping, Optional, Tuple, Union

import pandas as pd
import torch
from catalyst.utils import set_global_seed
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
class DatasetWrapper(Dataset):
    """Dataset that tokenizes raw texts on access.

    Each item is the tokenizer output for one text, with ``input_ids``
    re-exposed under the key ``features`` and, when labels were supplied,
    an integer ``targets`` entry.
    """

    def __init__(
        self,
        texts: List[str],
        labels: Optional[List[str]] = None,
        label_dict: Optional[Mapping[str, int]] = None,
        max_seq_length: int = 512,
        model_name: str = "distilbert-base-uncased",
    ):
        """Store the samples and load the tokenizer.

        Args:
            texts: raw input strings, one per sample.
            labels: optional class names, indexed the same way as ``texts``.
            label_dict: optional mapping from class name to integer id. When
                omitted and ``labels`` is given, it is built from the sorted
                unique labels.
            max_seq_length: length passed to the tokenizer as ``max_length``
                (used together with ``padding="max_length"`` and truncation).
            model_name: identifier passed to ``AutoTokenizer.from_pretrained``.
        """
        self.texts = texts
        self.labels = labels
        self.label_dict = label_dict
        self.max_seq_length = max_seq_length

        if self.label_dict is None and labels is not None:
            # Sorting the unique labels keeps the label -> id assignment
            # independent of the order in which labels appear.
            self.label_dict = {
                label: idx for idx, label in enumerate(sorted(set(labels)))
            }

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        logging.getLogger("transformers.tokenization_utils").setLevel(
            logging.FATAL
        )

        # Ask the tokenizer for its own special-token ids instead of looking
        # up the literal BERT spellings in the vocab: the literal lookup
        # raises KeyError for tokenizers that name these tokens differently.
        self.sep_vid = self.tokenizer.sep_token_id
        self.cls_vid = self.tokenizer.cls_token_id
        self.pad_vid = self.tokenizer.pad_token_id

    def __len__(self) -> int:
        """Return the number of texts."""
        return len(self.texts)

    def __getitem__(self, index) -> Mapping[str, torch.Tensor]:
        """Tokenize the text at ``index``.

        Returns:
            The tokenizer output with ``input_ids`` replaced by ``features``
            (leading dimension squeezed), plus ``targets`` when labels were
            supplied. Labels missing from ``label_dict`` are encoded as -1.
        """
        encoded = self.tokenizer.encode_plus(
            self.texts[index],
            add_special_tokens=True,
            padding="max_length",
            max_length=self.max_seq_length,
            return_tensors="pt",
            truncation=True,
            return_attention_mask=True,
        )

        # Squeeze the leading dimension of the ids so that batching stacks
        # samples along a fresh first axis.
        # NOTE(review): only input_ids is squeezed; attention_mask is returned
        # exactly as the tokenizer produced it -- confirm the consumer expects
        # that shape.
        encoded["features"] = encoded.pop("input_ids").squeeze(0)

        if self.labels is not None:
            label = self.labels[index]
            # Build the integer tensor directly: going through the default
            # float tensor first would lose precision for ids above 2**24.
            encoded["targets"] = torch.tensor(
                self.label_dict.get(label, -1), dtype=torch.long
            )

        return encoded
def read_data(
    text: Optional[Union[str, List[str]]] = None,
    *,
    max_seq_length: int = 128,
    model_name: str = "distilbert-base-uncased",
    batch_size: int = 8,
    seed: int = 17,
) -> Mapping[str, DataLoader]:
    """Build the unlabeled, unshuffled data loader for the given texts.

    Args:
        text: the texts to wrap; a single string is treated as one sample.
        max_seq_length: forwarded to ``DatasetWrapper``.
        model_name: forwarded to ``DatasetWrapper``.
        batch_size: batch size of the returned loader.
        seed: value passed to ``set_global_seed``.

    Returns:
        A dict with a single key, ``"test"``, holding the ``DataLoader``.

    Raises:
        ValueError: if ``text`` is ``None``.
    """
    if text is None:
        # Fail here with a clear message instead of later, when the loader
        # first calls len() on the missing texts.
        raise ValueError("read_data() requires `text`, got None")
    if isinstance(text, str):
        # A bare string would otherwise be indexed character by character.
        text = [text]

    dataset = DatasetWrapper(
        texts=text,
        max_seq_length=max_seq_length,
        model_name=model_name,
    )

    set_global_seed(seed)

    return {
        "test": DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
        )
    }