"""
@author:jishnuprakash
"""
# Constants, helper functions, and classes used for training the multi-label tagger
import re
import nltk
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup

import pytorch_lightning as pl
# pytorch_lightning.metrics was removed in later releases; torchmetrics is its successor
from torchmetrics.functional import auroc
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords

random_seed = 42
num_epochs = 10
batch = 1            # batch size
threshold = 0.5      # decision threshold on sigmoid outputs
max_tokens = 512     # BERT's maximum input length
clean_text = False   # whether to run preprocess_text over every paragraph

# bert_model = "bert-base-uncased"
bert_model = "nlpaueb/legal-bert-base-uncased"
checkpoint_dir = "checkpoints"
check_filename = "legal-full-data"

earlystop_monitor = "val_loss"
earlystop_patience = 2

lex_classes = ["Article 2", "Article 3", "Article 5", "Article 6", 
               "Article 8", "Article 9", "Article 10", "Article 11",
               "Article 14", "Article 1 of Protocol 1", "No Violation"]

num_classes = len(lex_classes)

# Stop words and lemmatizer (require the NLTK "stopwords", "wordnet", and
# "punkt" resources; download once with nltk.download(...) if missing)
stop_words = stopwords.words("english")
lemmatizer = WordNetLemmatizer()

def preprocess_text(text, remove_stopwords, stop_words):
    """
    Lowercase text, strip non-letter characters, and optionally
    remove stopwords and lemmatize
    """
    text = text.lower()
    # remove special chars and numbers
    text = re.sub("[^A-Za-z]+", " ", text)
    # remove stopwords
    if remove_stopwords:
        # 1. tokenize
        tokens = nltk.word_tokenize(text)
        # 2. drop stopwords (text is already lowercased)
        tokens = [w for w in tokens if w not in stop_words]
        # 3. lemmatize
        tokens = [lemmatizer.lemmatize(i) for i in tokens]
        # 4. join back together
        text = " ".join(tokens)
    # return text stripped of surrounding whitespace
    return text.strip()
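
# Hedged example (illustration only):
#   preprocess_text("The Court held, in 2001, that...", True, stop_words)
#   -> "court held"   (digits and punctuation stripped, stopwords removed, lemmatized)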

def preprocess_data(df, clean=False):
    """
    Perform basic data preprocessing
    """
    # drop rows with empty text, and map an empty label list to the
    # "No Violation" class (the last index in lex_classes)
    df = df[df['text'].map(len) > 0].copy()
    df['labels'] = df.labels.apply(lambda x: x if len(x) > 0 else [num_classes - 1])
    df.dropna(inplace=True)
    if clean:
        df['text'] = df.apply(lambda x: [preprocess_text(i, True, stop_words) for i in x['text']], axis=1)
    return df
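
# Illustrative loading sketch (assumptions: the HuggingFace `datasets` package
# and a LexGLUE ECtHR configuration, whose splits expose `text`/`labels`
# columns matching what preprocess_data expects; not part of the original file):
#   import pandas as pd
#   from datasets import load_dataset
#   raw = load_dataset("lex_glue", "ecthr_a")
#   train_df = preprocess_data(pd.DataFrame(raw["train"]), clean=clean_text)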

class LexGlueDataset(Dataset):
    """
    Lex GLUE Dataset as pytorch dataset
    """

    def __init__(self, data, tokenizer, max_tokens=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.data = data
        self.max_tokens = max_tokens

    def __len__(self):
        return len(self.data)
    
    def generateLabels(self, labels):
        """
        Multi-hot encode a list of label indices,
        e.g. [0, 4] -> [1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
        """
        out = [0] * num_classes
        for i in labels:
            out[i] = 1
        return out

    def __getitem__(self, index):
        data_row = self.data.iloc[index]
        lex_text = data_row.text
        multi_labels = self.generateLabels(data_row.labels)

        # `lex_text` is a list of paragraph strings, hence is_split_into_words=True
        encoding = self.tokenizer.encode_plus(lex_text,
                                              add_special_tokens=True,
                                              max_length=self.max_tokens,
                                              return_token_type_ids=False,
                                              padding="max_length",
                                              truncation=True,
                                              return_attention_mask=True,
                                              is_split_into_words=True,
                                              return_tensors='pt',)
        
        return dict(text = lex_text,
                    input_ids = encoding["input_ids"].flatten(),
                    attention_mask = encoding["attention_mask"].flatten(),
                    labels = torch.FloatTensor(multi_labels))
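
# Usage sketch (hypothetical `train_df` prepared with preprocess_data above):
#   tokenizer = AutoTokenizer.from_pretrained(bert_model)
#   sample = LexGlueDataset(train_df, tokenizer, max_tokens)[0]
#   sample["input_ids"].shape   # torch.Size([512])
#   sample["labels"].shape      # torch.Size([11])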


class LexGlueDataModule(pl.LightningDataModule):
    """
    Data module to load LexGlueDataset for training, validating and testing
    """

    def __init__(self, train, test, tokenizer, batch_size=8, max_tokens=512):
        super().__init__()
        self.batch_size = batch_size
        self.train = train
        self.test = test
        self.tokenizer = tokenizer
        self.max_tokens = max_tokens

    def setup(self, stage=None):
        self.train_dataset = LexGlueDataset(self.train, 
                                            self.tokenizer,
                                            self.max_tokens)

        self.test_dataset = LexGlueDataset(self.test, 
                                           self.tokenizer,
                                           self.max_tokens)
        
    def train_dataloader(self):
        return DataLoader(self.train_dataset, 
                          batch_size=self.batch_size,
                          shuffle=True,)

    def val_dataloader(self):
        # the test split doubles as the validation set here
        return DataLoader(self.test_dataset, 
                          batch_size=self.batch_size,)

    def test_dataloader(self):
        return DataLoader(self.test_dataset,
                          batch_size=self.batch_size,)
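
# Usage sketch (hypothetical DataFrames, prepared as in the loading sketch above):
#   data_module = LexGlueDataModule(train_df, test_df, tokenizer,
#                                   batch_size=batch, max_tokens=max_tokens)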


class LexGlueTagger(pl.LightningModule):
    """
    BERT encoder with a linear multi-label classification head,
    wrapped as a PyTorch Lightning module
    """

    def __init__(self, num_classes, training_steps=None, warmup_steps=None):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model, return_dict=True)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)
        self.training_steps = training_steps
        self.warmup_steps = warmup_steps
        # forward() applies a sigmoid, so plain BCELoss is used;
        # nn.BCEWithLogitsLoss on raw logits would be more numerically stable
        self.criterion = nn.BCELoss()
    
    def forward(self, input_ids, attention_mask, labels=None):
        """
        Forward pass
        """
        output = self.bert(input_ids, attention_mask=attention_mask)
        output = self.classifier(output.pooler_output)
        output = torch.sigmoid(output)    
        loss = 0
        if labels is not None:
            loss = self.criterion(output, labels)
        return loss, output
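
    # Inference sketch: apply the module-level `threshold` to the returned
    # probabilities to obtain hard multi-label predictions, e.g.
    #   _, probs = model(input_ids, attention_mask)
    #   preds = (probs > threshold).int()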

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {"loss": loss, "predictions": outputs, "labels": labels}

    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("test_loss", loss, prog_bar=True, logger=True)
        return loss

    def training_epoch_end(self, outputs):
        labels = []
        predictions = []
        for output in outputs:
            for out_labels in output["labels"].detach().cpu():
                labels.append(out_labels)
            for out_predictions in output["predictions"].detach().cpu():
                predictions.append(out_predictions)
        
        labels = torch.stack(labels).int()
        predictions = torch.stack(predictions)

        for i, name in enumerate(lex_classes):
            # per-class ROC-AUC (recent torchmetrics versions require an explicit task)
            class_roc_auc = auroc(predictions[:, i], labels[:, i], task="binary")
            self.logger.experiment.add_scalar(f"{name}_roc_auc/Train", class_roc_auc, self.current_epoch)
        

    def configure_optimizers(self):
        """
        Optimizer and learning rate scheduler
        """
        # torch.optim.AdamW replaces the AdamW that transformers has deprecated
        optimizer = torch.optim.AdamW(self.parameters(), lr=2e-5)
        scheduler = get_linear_schedule_with_warmup(optimizer,
                                                    num_warmup_steps=self.warmup_steps,
                                                    num_training_steps=self.training_steps)
        return dict(optimizer=optimizer,
                    lr_scheduler=dict(scheduler=scheduler,
                                      interval='step'))
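

# A minimal end-to-end wiring sketch. This block is an illustration, not the
# original training script: it assumes the HuggingFace `datasets` package, the
# LexGLUE "ecthr_a" configuration (any ECtHR config matching the 11 classes
# above would fit), and a PyTorch Lightning 1.x Trainer API.
if __name__ == "__main__":
    import pandas as pd
    from datasets import load_dataset
    from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

    pl.seed_everything(random_seed)

    # load and preprocess the ECtHR splits
    raw = load_dataset("lex_glue", "ecthr_a")
    train_df = preprocess_data(pd.DataFrame(raw["train"]), clean=clean_text)
    test_df = preprocess_data(pd.DataFrame(raw["test"]), clean=clean_text)

    tokenizer = AutoTokenizer.from_pretrained(bert_model)
    data_module = LexGlueDataModule(train_df, test_df, tokenizer,
                                    batch_size=batch, max_tokens=max_tokens)

    # warm up over the first fifth of training (an assumed split, not a value
    # taken from the original code)
    total_steps = (len(train_df) // batch) * num_epochs
    model = LexGlueTagger(num_classes,
                          training_steps=total_steps,
                          warmup_steps=total_steps // 5)

    trainer = pl.Trainer(max_epochs=num_epochs,
                         callbacks=[EarlyStopping(monitor=earlystop_monitor,
                                                  patience=earlystop_patience),
                                    ModelCheckpoint(dirpath=checkpoint_dir,
                                                    filename=check_filename,
                                                    monitor=earlystop_monitor)])
    trainer.fit(model, datamodule=data_module)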