""" | |
@author:jishnuprakash | |
""" | |
# This file consists of constants, attributes and classes used for training | |
import re | |
import nltk | |
import torch | |
import torch.nn as nn | |
from torch.utils.data import Dataset, DataLoader | |
from transformers import BertTokenizerFast as BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup | |
import pytorch_lightning as pl | |
from pytorch_lightning.metrics.functional import auroc | |
from nltk.stem import WordNetLemmatizer | |
from nltk.corpus import stopwords | |
from transformers import AutoTokenizer, AutoModel | |
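# NOTE: this module targets the PyTorch Lightning 1.x API (the
# training_epoch_end hook below was removed in Lightning 2.0) and the
# task-based auroc introduced in torchmetrics 0.11.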
random_seed = 42        # seed for reproducibility
num_epochs = 10         # training epochs
batch = 1               # batch size for the DataLoaders
threshold = 0.5         # probability cutoff for predicting a label
max_tokens = 512        # BERT input length limit
clean_text = False      # set True to strip stopwords and lemmatize
# bert_model = "bert-base-uncased"
bert_model = "nlpaueb/legal-bert-base-uncased"
checkpoint_dir = "checkpoints"      # where ModelCheckpoint writes weights
check_filename = "legal-full-data"  # checkpoint file name
earlystop_monitor = "val_loss"      # metric watched by EarlyStopping
earlystop_patience = 2              # epochs without improvement before stopping
lex_classes = ["Article 2", "Article 3", "Article 5", "Article 6",
               "Article 8", "Article 9", "Article 10", "Article 11",
               "Article 14", "Article 1 of Protocol 1", "No Violation"]
num_classes = len(lex_classes)
# Stop words and lemmatizer used by preprocess_text
stop_words = stopwords.words("english")
lemmatizer = WordNetLemmatizer()
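# NOTE: the first run needs the NLTK data used below, e.g.
# nltk.download("stopwords"); nltk.download("punkt"); nltk.download("wordnet")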
def preprocess_text(text, remove_stopwords, stop_words):
    """
    Clean a single string: lowercase, drop non-letters, and optionally
    remove stopwords and lemmatize.
    """
    text = text.lower()
    # remove special chars and numbers
    text = re.sub("[^A-Za-z]+", " ", text)
    # remove stopwords
    if remove_stopwords:
        # 1. tokenize
        tokens = nltk.word_tokenize(text)
        # 2. drop stopwords (text is already lowercase)
        tokens = [w for w in tokens if w not in stop_words]
        # 3. lemmatize
        tokens = [lemmatizer.lemmatize(i) for i in tokens]
        # 4. join back together
        text = " ".join(tokens)
    # return text stripped of surrounding whitespace
    return text.strip()
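# Example: preprocess_text("The 2 cats ran!", True, stop_words) -> "cat ran"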
def preprocess_data(df, clean=False):
    """
    Perform basic data preprocessing
    """
    # drop empty documents; copy to avoid pandas SettingWithCopyWarning
    df = df[df['text'].map(len) > 0].copy()
    # cases with no violated article get the "No Violation" class (index 10)
    df['labels'] = df.labels.apply(lambda x: x if len(x) > 0 else [10])
    df.dropna(inplace=True)
    if clean:
        df['text'] = df.apply(lambda x: [preprocess_text(i, True, stop_words) for i in x['text']], axis=1)
    return df
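# Expected input: a DataFrame in the LexGLUE ECtHR format, i.e. a "text"
# column holding a list of paragraph strings per case and a "labels" column
# holding a list of class indices into lex_classes.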
class LexGlueDataset(Dataset):
    """
    LexGLUE dataset wrapped as a PyTorch Dataset
    """
    def __init__(self, data, tokenizer, max_tokens=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.data = data
        self.max_tokens = max_tokens

    def __len__(self):
        return len(self.data)

    def generate_labels(self, labels):
        # turn a list of class indices into a multi-hot vector,
        # e.g. [0, 2] -> [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]
        out = [0] * num_classes
        for i in labels:
            out[i] = 1
        return out

    def __getitem__(self, index):
        data_row = self.data.iloc[index]
        lex_text = data_row.text
        multi_labels = self.generate_labels(data_row.labels)
        # lex_text is a list of paragraphs; is_split_into_words=True lets the
        # tokenizer treat it as one pre-split sequence
        encoding = self.tokenizer.encode_plus(lex_text,
                                              add_special_tokens=True,
                                              max_length=self.max_tokens,
                                              return_token_type_ids=False,
                                              padding="max_length",
                                              truncation=True,
                                              return_attention_mask=True,
                                              is_split_into_words=True,
                                              return_tensors='pt')
        return dict(text=lex_text,
                    input_ids=encoding["input_ids"].flatten(),
                    attention_mask=encoding["attention_mask"].flatten(),
                    labels=torch.FloatTensor(multi_labels))
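# Quick smoke test (assumes a DataFrame `df` returned by preprocess_data):
#   tokenizer = AutoTokenizer.from_pretrained(bert_model)
#   item = LexGlueDataset(df, tokenizer)[0]
#   item["input_ids"].shape   # -> torch.Size([max_tokens])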
class LexGlueDataModule(pl.LightningDataModule):
    """
    Data module to load LexGlueDataset for training, validation and testing
    """
    def __init__(self, train, test, tokenizer, batch_size=8, max_tokens=512):
        super().__init__()
        self.batch_size = batch_size
        self.train = train
        self.test = test
        self.tokenizer = tokenizer
        self.max_tokens = max_tokens

    def setup(self, stage=None):
        self.train_dataset = LexGlueDataset(self.train,
                                            self.tokenizer,
                                            self.max_tokens)
        # the test split doubles as the validation set
        self.test_dataset = LexGlueDataset(self.test,
                                           self.tokenizer,
                                           self.max_tokens)

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.test_dataset,
                          batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset,
                          batch_size=self.batch_size)
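# Example wiring (hypothetical train_df / test_df DataFrames):
#   dm = LexGlueDataModule(train_df, test_df, tokenizer, batch_size=batch)
#   dm.setup()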
class LexGlueTagger(pl.LightningModule):
    """
    Model and training logic as a PyTorch Lightning module
    """
    def __init__(self, num_classes, training_steps=None, warmup_steps=None):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model, return_dict=True)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)
        self.training_steps = training_steps
        self.warmup_steps = warmup_steps
        # BCELoss over sigmoid probabilities; nn.BCEWithLogitsLoss would be
        # the more numerically stable alternative if forward returned logits
        self.criterion = nn.BCELoss()

    def forward(self, input_ids, attention_mask, labels=None):
        """
        Forward pass: returns (loss, per-class probabilities)
        """
        output = self.bert(input_ids, attention_mask=attention_mask)
        output = self.classifier(output.pooler_output)
        output = torch.sigmoid(output)
        loss = 0
        if labels is not None:
            loss = self.criterion(output, labels)
        return loss, output
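    # For reference: probs has shape (batch, num_classes); downstream code can
    # compare probs against `threshold` to obtain multi-label predictions.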
    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {"loss": loss, "predictions": outputs, "labels": labels}

    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("test_loss", loss, prog_bar=True, logger=True)
        return loss
    def training_epoch_end(self, outputs):
        # Lightning 1.x hook: aggregate every batch's predictions and labels,
        # then log a per-class ROC-AUC to TensorBoard
        labels = []
        predictions = []
        for output in outputs:
            for out_labels in output["labels"].detach().cpu():
                labels.append(out_labels)
            for out_predictions in output["predictions"].detach().cpu():
                predictions.append(out_predictions)
        labels = torch.stack(labels).int()
        predictions = torch.stack(predictions)
        for i, name in enumerate(lex_classes):
            # task="binary" is required by torchmetrics >= 0.11
            class_roc_auc = auroc(predictions[:, i], labels[:, i], task="binary")
            self.logger.experiment.add_scalar(f"{name}_roc_auc/Train", class_roc_auc, self.current_epoch)
    def configure_optimizers(self):
        """
        Optimizer and learning rate scheduler
        """
        # torch.optim.AdamW replaces the deprecated transformers.AdamW
        optimizer = torch.optim.AdamW(self.parameters(), lr=2e-5)
        scheduler = get_linear_schedule_with_warmup(optimizer,
                                                    num_warmup_steps=self.warmup_steps,
                                                    num_training_steps=self.training_steps)
        return dict(optimizer=optimizer,
                    lr_scheduler=dict(scheduler=scheduler,
                                      interval='step'))
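# A minimal end-to-end training sketch, not part of the pipeline above: it
# assumes the LexGLUE ECtHR-A split from the Hugging Face hub, and the
# Trainer flags and the 1/5 warmup ratio are illustrative choices.
if __name__ == "__main__":
    from datasets import load_dataset
    from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

    pl.seed_everything(random_seed)

    # "text" is a list of paragraphs per case, "labels" a list of class indices
    train_df = preprocess_data(load_dataset("lex_glue", "ecthr_a", split="train").to_pandas(), clean=clean_text)
    test_df = preprocess_data(load_dataset("lex_glue", "ecthr_a", split="test").to_pandas(), clean=clean_text)

    tokenizer = AutoTokenizer.from_pretrained(bert_model)
    data_module = LexGlueDataModule(train_df, test_df, tokenizer,
                                    batch_size=batch, max_tokens=max_tokens)

    # linear schedule: warm up over the first fifth of all optimizer steps
    total_steps = (len(train_df) // batch) * num_epochs
    model = LexGlueTagger(num_classes, training_steps=total_steps,
                          warmup_steps=total_steps // 5)

    trainer = pl.Trainer(max_epochs=num_epochs,
                         callbacks=[ModelCheckpoint(dirpath=checkpoint_dir,
                                                    filename=check_filename,
                                                    monitor=earlystop_monitor),
                                    EarlyStopping(monitor=earlystop_monitor,
                                                  patience=earlystop_patience)])
    trainer.fit(model, datamodule=data_module)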