import pandas as pd
import torch

from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningDataModule


class PyTorchDataModule(Dataset):
    """PyTorch ``Dataset`` that tokenizes sentences from a DataFrame.

    Expects a ``sentence`` column and, optionally, a ``label`` column.
    """

    def __init__(self, model_name_or_path: str, data: pd.DataFrame, max_seq_length: int = 256):
        """Initializes the dataset with a tokenizer and the input DataFrame."""
        self.model_name_or_path = model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
        self.data = data
        self.max_seq_length = max_seq_length

    def __len__(self):
        """Returns the number of examples in the data."""
        return len(self.data)

    def __getitem__(self, index: int):
        """Returns a dictionary of input tensors to feed into the model."""

        data_row = self.data.iloc[index]
        sentence = data_row["sentence"]

        sentence_encoding = self.tokenizer(
            sentence,
            max_length=self.max_seq_length,
            padding="max_length",
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        
        # flatten() removes the extra batch dimension added by return_tensors="pt"
        out = dict(
            sentence=sentence,
            input_ids=sentence_encoding["input_ids"].flatten(),
            attention_mask=sentence_encoding["attention_mask"].flatten(),
        )

        # Labels are optional so the same dataset can also be used for prediction
        if "label" in self.data.columns:
            out["labels"] = torch.tensor(data_row["label"]).flatten()

        return out


class DataModule(LightningDataModule):
    """LightningDataModule that wraps train, eval, and optional test DataFrames."""

    def __init__(
        self,
        model_name_or_path: str,
        train_df: pd.DataFrame,
        eval_df: pd.DataFrame,
        test_df: pd.DataFrame,
        max_seq_length: int = 256,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        num_workers: int = 4,
        **kwargs
    ):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.train_df = train_df
        self.eval_df = eval_df
        self.test_df = test_df
        self.max_seq_length = max_seq_length
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        self.train_dataset = PyTorchDataModule(self.model_name_or_path, self.train_df, self.max_seq_length)
        self.eval_dataset = PyTorchDataModule(self.model_name_or_path, self.eval_df, self.max_seq_length)

        # The test set is optional; build it only when a DataFrame was provided
        if isinstance(self.test_df, pd.DataFrame):
            self.test_dataset = PyTorchDataModule(self.model_name_or_path, self.test_df, self.max_seq_length)

    def train_dataloader(self) -> DataLoader:
        # Shuffle the training data each epoch
        return DataLoader(self.train_dataset, batch_size=self.train_batch_size, shuffle=True, num_workers=self.num_workers, pin_memory=True)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.eval_dataset, batch_size=self.eval_batch_size, shuffle=False, num_workers=self.num_workers, pin_memory=True)

    def predict_dataloader(self) -> DataLoader:
        # Predictions are returned as a single batch covering the whole test set
        return DataLoader(self.test_dataset, batch_size=len(self.test_dataset), shuffle=False, num_workers=self.num_workers, pin_memory=True)
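

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only). The model name and the tiny
    # in-memory DataFrames below are assumptions for demonstration, not part
    # of this module; any Hugging Face checkpoint and any DataFrames with a
    # "sentence" column (and optional "label" column) would work the same way.
    train_df = pd.DataFrame({"sentence": ["a good movie", "a dull movie"], "label": [1, 0]})
    eval_df = pd.DataFrame({"sentence": ["an average movie"], "label": [1]})
    test_df = pd.DataFrame({"sentence": ["worth watching"]})

    dm = DataModule(
        model_name_or_path="bert-base-uncased",
        train_df=train_df,
        eval_df=eval_df,
        test_df=test_df,
        max_seq_length=32,
        train_batch_size=2,
        eval_batch_size=2,
        num_workers=0,
    )
    dm.setup()

    # Inspect one training batch: input_ids and attention_mask are (batch, seq_len)
    batch = next(iter(dm.train_dataloader()))
    print(batch["input_ids"].shape, batch["attention_mask"].shape, batch["labels"].shape)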