jishnuprakash committed on
Commit
75a0868
1 Parent(s): 6da871e

utils and ui

Files changed (1)
  1. utils.py +230 -0
utils.py ADDED
@@ -0,0 +1,230 @@
+ """
+ @author:jishnuprakash
+ """
+ # This file consists of constants, attributes and classes used for training
+ import re
+ import nltk
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import Dataset, DataLoader
+ from transformers import BertTokenizerFast as BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
+
+ import pytorch_lightning as pl
+ from pytorch_lightning.metrics.functional import auroc
+ from nltk.stem import WordNetLemmatizer
+ from nltk.corpus import stopwords
+ from transformers import AutoTokenizer, AutoModel
+
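+ # Note: `pytorch_lightning.metrics` only exists in older Lightning releases;
+ # on newer versions the equivalent `auroc` lives in `torchmetrics.functional`.
+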
+ random_seed = 42
+ num_epochs = 10
+ batch = 1
+ threshold = 0.5
+ max_tokens = 512
+ clean_text = False
+
+ # bert_model = "bert-base-uncased"
+ bert_model = "nlpaueb/legal-bert-base-uncased"
+ checkpoint_dir = "checkpoints"
+ check_filename = "legal-full-data"
+
+ earlystop_monitor = "val_loss"
+ earlystop_patience = 2
+
+ lex_classes = ["Article 2", "Article 3", "Article 5", "Article 6",
+                "Article 8", "Article 9", "Article 10", "Article 11",
+                "Article 14", "Article 1 of Protocol 1", "No Violation"]
+
+ num_classes = len(lex_classes)
+
+ # Stop words
+ stop_words = stopwords.words("english")
+ lemmatizer = WordNetLemmatizer()
+
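+ # Note: tokenization, the stopword list and the lemmatizer assume the NLTK
+ # "punkt", "stopwords" and "wordnet" resources have been downloaded,
+ # e.g. nltk.download("punkt"); nltk.download("stopwords"); nltk.download("wordnet")
+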
+ def preprocess_text(text, remove_stopwords, stop_words):
+     """
+     Clean text
+     """
+     text = text.lower()
+     # remove special chars and numbers
+     text = re.sub("[^A-Za-z]+", " ", text)
+     # remove stopwords
+     if remove_stopwords:
+         # 1. tokenize
+         tokens = nltk.word_tokenize(text)
+         # 2. check if stopword
+         tokens = [w for w in tokens if not w.lower() in stop_words]
+         # 3. Lemmatize
+         tokens = [lemmatizer.lemmatize(i) for i in tokens]
+         # 4. join back together
+         text = " ".join(tokens)
+     # return text in lower case and stripped of whitespaces
+     text = text.lower().strip()
+     return text
+
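+ # Example (illustrative):
+ #   preprocess_text("The Court held in 2001 that...", True, stop_words)
+ #   -> "court held"  (numbers stripped, stopwords removed, tokens lemmatized)
+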
+ def preprocess_data(df, clean=False):
+     """
+     Perform basic data preprocessing
+     """
+     # drop documents with no text
+     df = df[df['text'].map(len) > 0]
+     # cases with no violated article get the "No Violation" class (index 10)
+     df['labels'] = df.labels.apply(lambda x: x if len(x) > 0 else [10])
+     df.dropna(inplace=True)
+     if clean:
+         df['text'] = df.apply(lambda x: [preprocess_text(i, True, stop_words) for i in x['text']], axis=1)
+     return df
+
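+ # Example (illustrative) -- assumes the LexGLUE ECtHR data loaded with the
+ # Hugging Face `datasets` library and converted to pandas:
+ #   dataset = datasets.load_dataset("lex_glue", "ecthr_a")
+ #   train_df = preprocess_data(dataset["train"].to_pandas(), clean=clean_text)
+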
+ class LexGlueDataset(Dataset):
+     """
+     Lex GLUE Dataset as pytorch dataset
+     """
+
+     def __init__(self, data, tokenizer, max_tokens=512):
+         super().__init__()
+         self.tokenizer = tokenizer
+         self.data = data
+         self.max_tokens = max_tokens
+
+     def __len__(self):
+         return len(self.data)
+
+     def generateLabels(self, labels):
+         # convert a list of class indices into a multi-hot vector of length num_classes
+         out = [0] * num_classes
+         for i in labels:
+             out[i] = 1
+         return out
+
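+     # Example (illustrative): for a case labelled with class indices [0, 4],
+     # generateLabels returns [1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0].
+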
+     def __getitem__(self, index):
+         data_row = self.data.iloc[index]
+         lex_text = data_row.text
+         multi_labels = self.generateLabels(data_row.labels)
+
+         # each item is a list of paragraphs, so the tokenizer is told the
+         # input is already split (is_split_into_words=True)
+         encoding = self.tokenizer.encode_plus(lex_text,
+                                               add_special_tokens=True,
+                                               max_length=self.max_tokens,
+                                               return_token_type_ids=False,
+                                               padding="max_length",
+                                               truncation=True,
+                                               return_attention_mask=True,
+                                               is_split_into_words=True,
+                                               return_tensors='pt',)
+
+         return dict(text=lex_text,
+                     input_ids=encoding["input_ids"].flatten(),
+                     attention_mask=encoding["attention_mask"].flatten(),
+                     labels=torch.FloatTensor(multi_labels))
+
+
+ class LexGlueDataModule(pl.LightningDataModule):
+     """
+     Data module to load LexGlueDataset for training, validating and testing
+     """
+
+     def __init__(self, train, test, tokenizer, batch_size=8, max_tokens=512):
+         super().__init__()
+         self.batch_size = batch_size
+         self.train = train
+         self.test = test
+         self.tokenizer = tokenizer
+         self.max_tokens = max_tokens
+
+     def setup(self, stage=None):
+         self.train_dataset = LexGlueDataset(self.train,
+                                             self.tokenizer,
+                                             self.max_tokens)
+
+         self.test_dataset = LexGlueDataset(self.test,
+                                            self.tokenizer,
+                                            self.max_tokens)
+
+     def train_dataloader(self):
+         return DataLoader(self.train_dataset,
+                           batch_size=self.batch_size,
+                           shuffle=True,)
+
+     def val_dataloader(self):
+         return DataLoader(self.test_dataset,
+                           batch_size=self.batch_size,)
+
+     def test_dataloader(self):
+         return DataLoader(self.test_dataset,
+                           batch_size=self.batch_size,)
+
+
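+ # Example (illustrative) -- assumes train_df / test_df are DataFrames prepared
+ # with preprocess_data above:
+ #   tokenizer = BertTokenizer.from_pretrained(bert_model)
+ #   data_module = LexGlueDataModule(train_df, test_df, tokenizer,
+ #                                   batch_size=batch, max_tokens=max_tokens)
+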
+ class LexGlueTagger(pl.LightningModule):
+     """
+     Model and Training instance as LexGlueTagger class for Pytorch Lightning module
+     """
+
+     def __init__(self, num_classes, training_steps=None, warmup_steps=None):
+         super().__init__()
+         self.bert = AutoModel.from_pretrained(bert_model, return_dict=True)
+         # one output per class; sigmoid + BCE gives independent multi-label probabilities
+         self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)
+         self.training_steps = training_steps
+         self.warmup_steps = warmup_steps
+         self.criterion = nn.BCELoss()
+
+     def forward(self, input_ids, attention_mask, labels=None):
+         """
+         Forward pass
+         """
+         output = self.bert(input_ids, attention_mask=attention_mask)
+         output = self.classifier(output.pooler_output)
+         output = torch.sigmoid(output)
+         loss = 0
+         if labels is not None:
+             loss = self.criterion(output, labels)
+         return loss, output
+
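+     # Note (illustrative): the returned probabilities have shape
+     # (batch_size, num_classes); downstream code can binarise them, e.g. by
+     # comparing against the `threshold` constant defined above.
+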
+     def training_step(self, batch, batch_idx):
+         input_ids = batch["input_ids"]
+         attention_mask = batch["attention_mask"]
+         labels = batch["labels"]
+         loss, outputs = self(input_ids, attention_mask, labels)
+         self.log("train_loss", loss, prog_bar=True, logger=True)
+         return {"loss": loss, "predictions": outputs, "labels": labels}
+
+     def validation_step(self, batch, batch_idx):
+         input_ids = batch["input_ids"]
+         attention_mask = batch["attention_mask"]
+         labels = batch["labels"]
+         loss, outputs = self(input_ids, attention_mask, labels)
+         self.log("val_loss", loss, prog_bar=True, logger=True)
+         return loss
+
+     def test_step(self, batch, batch_idx):
+         input_ids = batch["input_ids"]
+         attention_mask = batch["attention_mask"]
+         labels = batch["labels"]
+         loss, outputs = self(input_ids, attention_mask, labels)
+         self.log("test_loss", loss, prog_bar=True, logger=True)
+         return loss
+
+     def training_epoch_end(self, outputs):
+         # collect labels and predictions from every training step of the epoch
+         labels = []
+         predictions = []
+         for output in outputs:
+             for out_labels in output["labels"].detach().cpu():
+                 labels.append(out_labels)
+             for out_predictions in output["predictions"].detach().cpu():
+                 predictions.append(out_predictions)
+
+         labels = torch.stack(labels).int()
+         predictions = torch.stack(predictions)
+
+         # log per-class ROC-AUC to the experiment logger (e.g. TensorBoard)
+         for i, name in enumerate(lex_classes):
+             class_roc_auc = auroc(predictions[:, i], labels[:, i])
+             self.logger.experiment.add_scalar(f"{name}_roc_auc/Train", class_roc_auc, self.current_epoch)
+
+
+     def configure_optimizers(self):
+         """
+         Optimizer and Learning rate scheduler
+         """
+         optimizer = AdamW(self.parameters(), lr=2e-5)
+         scheduler = get_linear_schedule_with_warmup(optimizer,
+                                                     num_warmup_steps=self.warmup_steps,
+                                                     num_training_steps=self.training_steps)
+         return dict(optimizer=optimizer,
+                     lr_scheduler=dict(scheduler=scheduler,
+                                       interval='step'))
+
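+ # Example (illustrative end-to-end sketch; `train_df`, `test_df` and `data_module`
+ # are the placeholder names used in the examples above, not defined in this file):
+ #   steps_per_epoch = len(train_df) // batch
+ #   total_steps = steps_per_epoch * num_epochs
+ #   model = LexGlueTagger(num_classes=num_classes,
+ #                         training_steps=total_steps,
+ #                         warmup_steps=total_steps // 5)
+ #   checkpoint = pl.callbacks.ModelCheckpoint(dirpath=checkpoint_dir,
+ #                                             filename=check_filename,
+ #                                             monitor=earlystop_monitor)
+ #   early_stop = pl.callbacks.EarlyStopping(monitor=earlystop_monitor,
+ #                                           patience=earlystop_patience)
+ #   trainer = pl.Trainer(max_epochs=num_epochs, callbacks=[checkpoint, early_stop])
+ #   trainer.fit(model, datamodule=data_module)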