# lex_glue_ecthrA / utils.py
# Author: jishnuprakash — commit 75a0868 ("utils and ui"), 8.14 kB
"""
@author:jishnuprakash
"""
# Constants, text-preprocessing helpers, and the dataset, data-module and model classes used for training.
import re
import nltk
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast as BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import auroc
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords
from transformers import AutoTokenizer, AutoModel
# ---- Reproducibility & training loop ---------------------------------------
random_seed = 42
num_epochs = 10
batch = 1  # presumably the DataLoader batch size used by the training script — confirm
threshold = 0.5  # presumably the cut-off applied to per-class sigmoid scores — confirm
max_tokens = 512  # same value as the max_tokens defaults of the classes below
clean_text = False  # presumably passed as preprocess_data(..., clean=clean_text) — confirm

# ---- Encoder backbone (loaded by LexGlueTagger) ----------------------------
# Alternative generic checkpoint: "bert-base-uncased".
bert_model = "nlpaueb/legal-bert-base-uncased"

# ---- Checkpointing & early stopping ----------------------------------------
checkpoint_dir = "checkpoints"
check_filename = "legal-full-data"
earlystop_monitor = "val_loss"  # metric logged by LexGlueTagger.validation_step
earlystop_patience = 2

# ---- Label space -----------------------------------------------------------
# List order defines the label index. The last entry, "No Violation" (index 10),
# is the label preprocess_data assigns to rows whose label list is empty.
lex_classes = [
    "Article 2",
    "Article 3",
    "Article 5",
    "Article 6",
    "Article 8",
    "Article 9",
    "Article 10",
    "Article 11",
    "Article 14",
    "Article 1 of Protocol 1",
    "No Violation",
]
num_classes = len(lex_classes)

# ---- NLTK resources used by preprocess_text --------------------------------
stop_words = stopwords.words("english")
lemmatizer = WordNetLemmatizer()
def preprocess_text(text, remove_stopwords, stop_words):
    """
    Clean a single string of text.

    The text is lower-cased, and every run of characters outside ``A-Z``/``a-z``
    is collapsed to a single space. That includes digits, punctuation,
    whitespace and non-ASCII letters. When ``remove_stopwords`` is true, the text
    is also tokenized with ``nltk.word_tokenize``, stop words are dropped, and the
    remaining tokens are lemmatized with the module-level ``lemmatizer``.

    Args:
        text: the string to clean.
        remove_stopwords: if true, also tokenize, remove stop words and
            lemmatize. Lemmatization only happens on this path.
        stop_words: iterable of stop words. It is only read when
            ``remove_stopwords`` is true.

    Returns:
        The cleaned, lower-cased string with leading/trailing whitespace removed.
    """
    text = text.lower()
    # Remove special characters and numbers. Only ASCII letters survive, so
    # accented letters are replaced by spaces as well.
    text = re.sub(r"[^A-Za-z]+", " ", text)
    if remove_stopwords:
        # Build the lookup set once per call. Membership in a list costs
        # O(len(stop_words)) per token; a set lookup is O(1).
        stop_set = set(stop_words)
        # 1. tokenize
        tokens = nltk.word_tokenize(text)
        # 2. drop stop words
        tokens = [w for w in tokens if w.lower() not in stop_set]
        # 3. lemmatize
        tokens = [lemmatizer.lemmatize(tok) for tok in tokens]
        # 4. join back together
        text = " ".join(tokens)
    # Return the text lower-cased and stripped of surrounding whitespace.
    return text.lower().strip()
def preprocess_data(df, clean=False):
    """
    Perform basic data preprocessing on a DataFrame with ``text`` and ``labels`` columns.

    Steps:
      * rows with missing values in any column are dropped;
      * rows whose ``text`` has length 0 are dropped;
      * an empty ``labels`` list is replaced by ``[10]``, the index of
        "No Violation" in ``lex_classes``;
      * if ``clean`` is true, every element of each ``text`` value is passed
        through ``preprocess_text`` (stop-word removal + lemmatization).

    The input DataFrame is not modified. A new DataFrame is returned.

    Args:
        df: DataFrame with ``text`` (sized, presumably a list of strings —
            confirm) and ``labels`` (sized list of class indices) columns.
        clean: whether to run ``preprocess_text`` on each text element.

    Returns:
        The filtered and relabelled DataFrame (original index preserved).
    """
    # Drop missing values first: len() below raises TypeError on NaN/None.
    # Not inplace, so the caller's DataFrame is left untouched.
    df = df.dropna()
    # Keep rows with non-empty text. .copy() detaches the result from the
    # parent frame, so the column assignments below do not raise
    # SettingWithCopyWarning.
    df = df[df["text"].map(len) > 0].copy()
    # Cases with no violated article get label 10 ("No Violation").
    df["labels"] = df["labels"].apply(lambda x: x if len(x) > 0 else [10])
    if clean:
        # Column-wise apply. Unlike a row-wise df.apply(axis=1), this also
        # works when the frame is empty.
        df["text"] = df["text"].apply(
            lambda texts: [preprocess_text(t, True, stop_words) for t in texts]
        )
    return df
class LexGlueDataset(Dataset):
    """
    LexGLUE data exposed as a PyTorch ``Dataset``.

    Each item is a dict with:
      * ``text``: the row's raw ``text`` value;
      * ``input_ids`` / ``attention_mask``: 1-D tensors of length ``max_tokens``
        (padded and truncated by the tokenizer);
      * ``labels``: a float multi-hot vector of length ``num_classes``.
    """

    def __init__(self, data, tokenizer, max_tokens=512):
        """
        Args:
            data: object indexed positionally via ``.iloc`` (e.g. a DataFrame)
                whose rows have ``text`` and ``labels`` attributes.
            tokenizer: Hugging Face tokenizer exposing ``encode_plus``.
            max_tokens: sequence length after padding / truncation.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.data = data
        self.max_tokens = max_tokens

    def __len__(self):
        return len(self.data)

    def generateLabels(self, labels):
        """
        Convert an iterable of class indices into a multi-hot list.

        The result has length ``num_classes``: 1 at every listed index, 0 elsewhere.
        """
        out = [0] * num_classes
        for idx in labels:
            out[idx] = 1
        return out

    def __getitem__(self, index):
        """Tokenize row ``index`` and return its model-ready feature dict."""
        data_row = self.data.iloc[index]
        lex_text = data_row.text
        multi_labels = self.generateLabels(data_row.labels)
        # is_split_into_words=True: text is expected to be a pre-split
        # sequence of strings (presumably the case's paragraphs — confirm).
        encoding = self.tokenizer.encode_plus(
            lex_text,
            add_special_tokens=True,
            max_length=self.max_tokens,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            is_split_into_words=True,
            return_tensors="pt",
        )
        # NOTE(review): ``text`` is returned unchanged. If it is a list whose
        # length varies between rows, torch's default collate will likely
        # reject batches larger than 1 — confirm against the batch size used.
        return dict(
            text=lex_text,
            input_ids=encoding["input_ids"].flatten(),
            attention_mask=encoding["attention_mask"].flatten(),
            labels=torch.tensor(multi_labels, dtype=torch.float32),
        )
class LexGlueDataModule(pl.LightningDataModule):
    """
    Lightning data module that serves LexGlueDataset for training, validation and testing.

    By default the ``test`` split is also used for validation (the original
    behavior). Pass ``val`` to validate on a separate split instead. Otherwise
    anything driven by validation metrics (early stopping on ``val_loss``,
    checkpoint selection) effectively selects on test data.
    """

    def __init__(self, train, test, tokenizer, batch_size=8, max_tokens=512, val=None):
        """
        Args:
            train: training rows (passed to LexGlueDataset).
            test: test rows; also used for validation when ``val`` is None.
            tokenizer: tokenizer forwarded to LexGlueDataset.
            batch_size: batch size for every DataLoader.
            max_tokens: sequence length forwarded to LexGlueDataset.
            val: optional separate validation rows.
        """
        super().__init__()
        self.batch_size = batch_size
        self.train = train
        self.test = test
        self.val = val
        self.tokenizer = tokenizer
        self.max_tokens = max_tokens

    def setup(self, stage=None):
        """Build the datasets. The stage argument is ignored; all splits are built."""
        self.train_dataset = LexGlueDataset(self.train, self.tokenizer, self.max_tokens)
        self.test_dataset = LexGlueDataset(self.test, self.tokenizer, self.max_tokens)
        if self.val is None:
            # Backward-compatible default: validate on the test split.
            self.val_dataset = self.test_dataset
        else:
            self.val_dataset = LexGlueDataset(self.val, self.tokenizer, self.max_tokens)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)
class LexGlueTagger(pl.LightningModule):
    """
    Multi-label classifier trained with PyTorch Lightning.

    Architecture: the pretrained encoder named by ``bert_model``, then a linear
    layer on its pooled output, giving one sigmoid probability per class.
    """

    def __init__(self, num_classes, training_steps=None, warmup_steps=None):
        """
        Args:
            num_classes: number of output labels.
            training_steps: total number of optimizer steps for the LR schedule.
            warmup_steps: number of warm-up steps for the LR schedule.
        """
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model, return_dict=True)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)
        self.training_steps = training_steps
        self.warmup_steps = warmup_steps
        # Takes raw logits. Fusing sigmoid + BCE is numerically stabler than
        # applying sigmoid first and then BCELoss.
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, input_ids, attention_mask, labels=None):
        """
        Forward pass.

        Returns:
            ``(loss, probs)``. ``probs`` are per-class sigmoid probabilities.
            ``loss`` is 0 when ``labels`` is None.
        """
        output = self.bert(input_ids, attention_mask=attention_mask)
        logits = self.classifier(output.pooler_output)
        probs = torch.sigmoid(logits)
        loss = 0
        if labels is not None:
            loss = self.criterion(logits, labels)
        return loss, probs

    def _shared_step(self, batch):
        """Run the model on a batch dict and return ``(loss, probs, labels)``."""
        labels = batch["labels"]
        loss, outputs = self(batch["input_ids"], batch["attention_mask"], labels)
        return loss, outputs, labels

    def training_step(self, batch, batch_idx):
        loss, outputs, labels = self._shared_step(batch)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        # Predictions are only used for epoch-level metrics. Detaching avoids
        # keeping a reference to this step's autograd graph until epoch end.
        return {"loss": loss, "predictions": outputs.detach(), "labels": labels}

    def validation_step(self, batch, batch_idx):
        loss, _, _ = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss, _, _ = self._shared_step(batch)
        self.log("test_loss", loss, prog_bar=True, logger=True)
        return loss

    def training_epoch_end(self, outputs):
        """Log the per-class training ROC-AUC for this epoch to the experiment logger."""
        # Concatenate the per-step batches along the sample dimension.
        labels = torch.cat([out["labels"].detach().cpu() for out in outputs]).int()
        predictions = torch.cat([out["predictions"].detach().cpu() for out in outputs])
        # NOTE(review): assumes num_classes == len(lex_classes).
        for i, name in enumerate(lex_classes):
            class_roc_auc = auroc(predictions[:, i], labels[:, i])
            self.logger.experiment.add_scalar(
                f"{name}_roc_auc/Train", class_roc_auc, self.current_epoch
            )

    def configure_optimizers(self):
        """
        Optimizer and learning-rate scheduler.

        Uses AdamW (lr=2e-5) with ``get_linear_schedule_with_warmup``. The
        scheduler is stepped after every optimizer step (``interval='step'``).
        """
        optimizer = AdamW(self.parameters(), lr=2e-5)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.warmup_steps,
            num_training_steps=self.training_steps,
        )
        return dict(
            optimizer=optimizer,
            lr_scheduler=dict(scheduler=scheduler, interval="step"),
        )