| """ | |
| @author:jishnuprakash | |
| """ | |
| # This file consists of constants, attributes and classes used for training | |
| import re | |
| import nltk | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import Dataset, DataLoader | |
| from transformers import BertTokenizerFast as BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup | |
| import pytorch_lightning as pl | |
| from pytorch_lightning.metrics.functional import auroc | |
| from nltk.stem import WordNetLemmatizer | |
| from nltk.corpus import stopwords | |
| from transformers import AutoTokenizer, AutoModel | |
# Training hyperparameters
random_seed = 42
num_epochs = 10
batch = 1
threshold = 0.5
max_tokens = 512
clean_text = False

# Pre-trained model, checkpointing and early stopping
# bert_model = "bert-base-uncased"
bert_model = "nlpaueb/legal-bert-base-uncased"
checkpoint_dir = "checkpoints"
check_filename = "legal-full-data"
earlystop_monitor = "val_loss"
earlystop_patience = 2

# Target classes: ECHR articles plus a "No Violation" class (index 10)
lex_classes = ["Article 2", "Article 3", "Article 5", "Article 6",
               "Article 8", "Article 9", "Article 10", "Article 11",
               "Article 14", "Article 1 of Protocol 1", "No Violation"]
num_classes = len(lex_classes)

# Stop words and lemmatizer for optional text cleaning
stop_words = stopwords.words("english")
lemmatizer = WordNetLemmatizer()
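# The resources used above and in preprocess_text are not bundled with NLTK; a one-time
# download (standard NLTK resource names) is needed before this module is imported:
#   nltk.download("stopwords"); nltk.download("punkt"); nltk.download("wordnet")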
def preprocess_text(text, remove_stopwords, stop_words):
    """
    Clean text
    """
    text = text.lower()
    # remove special chars and numbers
    text = re.sub("[^A-Za-z]+", " ", text)
    # remove stopwords
    if remove_stopwords:
        # 1. tokenize
        tokens = nltk.word_tokenize(text)
        # 2. check if stopword
        tokens = [w for w in tokens if not w.lower() in stop_words]
        # 3. Lemmatize
        tokens = [lemmatizer.lemmatize(i) for i in tokens]
        # 4. join back together
        text = " ".join(tokens)
    # return text in lower case and stripped of whitespaces
    text = text.lower().strip()
    return text
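# Example usage (illustration only; the sample sentence is made up, not taken from the dataset):
#   preprocess_text("The applicant was detained in 2001.", True, stop_words)
#   -> roughly "applicant detained" (stopwords and digits removed, tokens lemmatized)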
def preprocess_data(df, clean=False):
    """
    Perform basic data preprocessing
    """
    # drop rows with empty text and map empty label lists to the "No Violation" class (index 10)
    df = df[df['text'].map(len) > 0]
    df['labels'] = df.labels.apply(lambda x: x if len(x) > 0 else [10])
    df.dropna(inplace=True)
    if clean:
        df['text'] = df.apply(lambda x: [preprocess_text(i, True, stop_words) for i in x['text']], axis=1)
    return df
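# Example (hedged sketch; the record below is made up only to show the expected shape:
# 'text' holds a list of paragraph strings, 'labels' a list of class indices):
#   import pandas as pd
#   df = pd.DataFrame({"text": [["First paragraph.", "Second paragraph."]],
#                      "labels": [[]]})
#   preprocess_data(df).labels.tolist()   # -> [[10]]  (empty labels become "No Violation")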
class LexGlueDataset(Dataset):
    """
    Lex GLUE Dataset as pytorch dataset
    """
    def __init__(self, data, tokenizer, max_tokens=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.data = data
        self.max_tokens = max_tokens

    def __len__(self):
        return len(self.data)

    def generateLabels(self, labels):
        # multi-hot encode the list of label indices
        out = [0] * num_classes
        for i in labels:
            out[i] = 1
        return out

    def __getitem__(self, index):
        data_row = self.data.iloc[index]
        lex_text = data_row.text
        multi_labels = self.generateLabels(data_row.labels)
        encoding = self.tokenizer.encode_plus(lex_text,
                                              add_special_tokens=True,
                                              max_length=self.max_tokens,
                                              return_token_type_ids=False,
                                              padding="max_length",
                                              truncation=True,
                                              return_attention_mask=True,
                                              is_split_into_words=True,
                                              return_tensors='pt',)
        return dict(text=lex_text,
                    input_ids=encoding["input_ids"].flatten(),
                    attention_mask=encoding["attention_mask"].flatten(),
                    labels=torch.FloatTensor(multi_labels))
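# Example (hedged sketch; assumes a dataframe prepared with preprocess_data above):
#   tokenizer = AutoTokenizer.from_pretrained(bert_model)
#   dataset = LexGlueDataset(train_df, tokenizer, max_tokens)
#   item = dataset[0]          # dict with 'text', 'input_ids', 'attention_mask', 'labels'
#   item["input_ids"].shape    # -> torch.Size([512]) with max_tokens = 512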
class LexGlueDataModule(pl.LightningDataModule):
    """
    Data module to load LexGlueDataset for training, validating and testing
    """
    def __init__(self, train, test, tokenizer, batch_size=8, max_tokens=512):
        super().__init__()
        self.batch_size = batch_size
        self.train = train
        self.test = test
        self.tokenizer = tokenizer
        self.max_tokens = max_tokens

    def setup(self, stage=None):
        self.train_dataset = LexGlueDataset(self.train,
                                            self.tokenizer,
                                            self.max_tokens)
        self.test_dataset = LexGlueDataset(self.test,
                                           self.tokenizer,
                                           self.max_tokens)

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          shuffle=True,)

    # the test split is reused for validation
    def val_dataloader(self):
        return DataLoader(self.test_dataset,
                          batch_size=self.batch_size,)

    def test_dataloader(self):
        return DataLoader(self.test_dataset,
                          batch_size=self.batch_size,)
class LexGlueTagger(pl.LightningModule):
    """
    Model and Training instance as LexGlueTagger class for Pytorch Lightning module
    """
    def __init__(self, num_classes, training_steps=None, warmup_steps=None):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model, return_dict=True)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)
        self.training_steps = training_steps
        self.warmup_steps = warmup_steps
        # sigmoid outputs + BCELoss for multi-label classification
        self.criterion = nn.BCELoss()

    def forward(self, input_ids, attention_mask, labels=None):
        """
        Forward pass
        """
        output = self.bert(input_ids, attention_mask=attention_mask)
        output = self.classifier(output.pooler_output)
        output = torch.sigmoid(output)
        loss = 0
        if labels is not None:
            loss = self.criterion(output, labels)
        return loss, output

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {"loss": loss, "predictions": outputs, "labels": labels}

    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("test_loss", loss, prog_bar=True, logger=True)
        return loss

    def training_epoch_end(self, outputs):
        # compute per-class ROC-AUC over all training batches of the epoch
        labels = []
        predictions = []
        for output in outputs:
            for out_labels in output["labels"].detach().cpu():
                labels.append(out_labels)
            for out_predictions in output["predictions"].detach().cpu():
                predictions.append(out_predictions)
        labels = torch.stack(labels).int()
        predictions = torch.stack(predictions)
        for i, name in enumerate(lex_classes):
            class_roc_auc = auroc(predictions[:, i], labels[:, i])
            self.logger.experiment.add_scalar(f"{name}_roc_auc/Train", class_roc_auc, self.current_epoch)

    def configure_optimizers(self):
        """
        Optimizer and Learning rate scheduler
        """
        optimizer = AdamW(self.parameters(), lr=2e-5)
        scheduler = get_linear_schedule_with_warmup(optimizer,
                                                    num_warmup_steps=self.warmup_steps,
                                                    num_training_steps=self.training_steps)
        return dict(optimizer=optimizer,
                    lr_scheduler=dict(scheduler=scheduler,
                                      interval='step'))
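
# A minimal training sketch (not part of the original module). It assumes the ECtHR task A
# subset of LexGLUE loaded via Hugging Face `datasets` into dataframes with 'text' and
# 'labels' columns; adapt the loading step and the warmup-step choice to the actual setup.
if __name__ == "__main__":
    from datasets import load_dataset
    from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

    pl.seed_everything(random_seed)

    # Assumed loading step: LexGLUE ECtHR task A from the Hub
    lex = load_dataset("lex_glue", "ecthr_a")
    train_df = preprocess_data(lex["train"].to_pandas(), clean=clean_text)
    test_df = preprocess_data(lex["test"].to_pandas(), clean=clean_text)

    tokenizer = AutoTokenizer.from_pretrained(bert_model)
    data_module = LexGlueDataModule(train_df, test_df, tokenizer,
                                    batch_size=batch, max_tokens=max_tokens)

    steps_per_epoch = len(train_df) // batch
    model = LexGlueTagger(num_classes=num_classes,
                          training_steps=steps_per_epoch * num_epochs,
                          warmup_steps=steps_per_epoch // 5)  # warmup fraction is an arbitrary choice

    checkpoint_callback = ModelCheckpoint(dirpath=checkpoint_dir,
                                          filename=check_filename,
                                          monitor=earlystop_monitor,
                                          save_top_k=1)
    early_stopping = EarlyStopping(monitor=earlystop_monitor,
                                   patience=earlystop_patience)

    trainer = pl.Trainer(max_epochs=num_epochs,
                         callbacks=[checkpoint_callback, early_stopping])
    trainer.fit(model, datamodule=data_module)
    trainer.test(model, datamodule=data_module)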