import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer, BertTokenizerFast as BertTokenizer
from huggingface_hub import PyTorchModelHubMixin


class EurovocDataset(Dataset):
    """Single-chunk dataset: each document is truncated to ``max_token_len`` tokens."""

    def __init__(
        self,
        text: np.ndarray,
        labels: np.ndarray,
        tokenizer: BertTokenizer,
        max_token_len: int = 128,
    ):
        self.tokenizer = tokenizer
        self.text = text
        self.labels = labels
        self.max_token_len = max_token_len

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index: int):
        # ``text`` is expected to have shape (n_docs, 1), holding one string per row.
        text = self.text[index][0]
        labels = self.labels[index]

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_token_len,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )

        return dict(
            text=text,
            input_ids=encoding["input_ids"].flatten(),
            attention_mask=encoding["attention_mask"].flatten(),
            labels=torch.FloatTensor(labels),
        )


class EuroVocLongTextDataset(Dataset):
    """Long-document dataset: each document is split into chunks of at most
    ``max_token_len`` words, and every chunk inherits the document's labels."""

    @staticmethod
    def __splitter__(text, max_length):
        words = text.split()
        for i in range(0, len(words), max_length):
            # Re-join each chunk into a string so it can be passed to the tokenizer.
            yield " ".join(words[i:i + max_length])

    def __init__(
        self,
        text: np.ndarray,
        labels: np.ndarray,
        tokenizer: BertTokenizer,
        max_token_len: int = 128,
    ):
        self.tokenizer = tokenizer
        self.text = text
        self.labels = labels
        self.max_token_len = max_token_len

        # Expand (document, labels) pairs into (chunk, labels) pairs.
        self.chunks_and_labels = [
            (c, l)
            for t, l in zip(self.text, self.labels)
            for c in self.__splitter__(t, self.max_token_len)
        ]

        # Tokenize all chunks up front in a single batch.
        self.encoding = self.tokenizer.batch_encode_plus(
            [c for c, _ in self.chunks_and_labels],
            add_special_tokens=True,
            max_length=self.max_token_len,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )

    def __len__(self):
        return len(self.chunks_and_labels)

    def __getitem__(self, index: int):
        text, labels = self.chunks_and_labels[index]

        return dict(
            text=text,
            input_ids=self.encoding["input_ids"][index].flatten(),
            attention_mask=self.encoding["attention_mask"][index].flatten(),
            labels=torch.FloatTensor(labels),
        )


class EurovocDataModule(pl.LightningDataModule):

    def __init__(self, bert_model_name, x_tr, y_tr, x_test, y_test, batch_size=8, max_token_len=512):
        super().__init__()
        self.batch_size = batch_size
        self.x_tr = x_tr
        self.y_tr = y_tr
        self.x_test = x_test
        self.y_test = y_test
        self.tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
        self.max_token_len = max_token_len

    def setup(self, stage=None):
        self.train_dataset = EurovocDataset(self.x_tr, self.y_tr, self.tokenizer, self.max_token_len)
        self.test_dataset = EurovocDataset(self.x_test, self.y_test, self.tokenizer, self.max_token_len)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=2)

    def val_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=2)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=2)


class EurovocTagger(pl.LightningModule, PyTorchModelHubMixin):

    def __init__(self, bert_model_name, n_classes, lr=2e-5, eps=1e-8):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(p=0.2)
        self.classifier1 = nn.Linear(self.bert.config.hidden_size, n_classes)
        self.criterion = nn.BCELoss()
        self.lr = lr
        self.eps = eps

    def forward(self, input_ids, attention_mask, labels=None):
        output = self.bert(input_ids, attention_mask=attention_mask)
        output = self.dropout(output.pooler_output)
        output = self.classifier1(output)
        output = torch.sigmoid(output)
        loss = 0
        if labels is not None:
            loss = self.criterion(output, labels)
        return loss, output

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {"loss": loss, "predictions": outputs, "labels": labels}

    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("test_loss", loss, prog_bar=True, logger=True)
        return loss

    def on_train_epoch_end(self, *args, **kwargs):
        # Per-class ROC-AUC logging is currently disabled.
        return
        # labels = []
        # predictions = []
        # for output in args['outputs']:
        #     for out_labels in output["labels"].detach().cpu():
        #         labels.append(out_labels)
        #     for out_predictions in output["predictions"].detach().cpu():
        #         predictions.append(out_predictions)
        # labels = torch.stack(labels).int()
        # predictions = torch.stack(predictions)
        # for i, name in enumerate(mlb.classes_):
        #     class_roc_auc = auroc(predictions[:, i], labels[:, i])
        #     self.logger.experiment.add_scalar(f"{name}_roc_auc/Train", class_roc_auc, self.current_epoch)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr, eps=self.eps)
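

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module): builds a tiny
    # random multi-label dataset and runs a single fast_dev_run pass through the
    # model. The checkpoint name, array shapes, and hyperparameters below are
    # illustrative assumptions, not values taken from this repository.
    rng = np.random.default_rng(0)
    n_docs, n_classes = 8, 5
    x = np.array([["example eurovoc document text"] for _ in range(n_docs)])  # shape (n_docs, 1)
    y = rng.integers(0, 2, size=(n_docs, n_classes)).astype(np.float32)

    dm = EurovocDataModule("bert-base-uncased", x, y, x, y, batch_size=2, max_token_len=64)
    model = EurovocTagger("bert-base-uncased", n_classes=n_classes)

    trainer = pl.Trainer(fast_dev_run=True, accelerator="cpu", logger=False)
    trainer.fit(model, datamodule=dm)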