"""DocFormer document-classification model and its PyTorch Lightning wrapper."""
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import seaborn as sns
import torch
import torch.nn as nn
import torchmetrics
import wandb
from sklearn.metrics import accuracy_score, confusion_matrix

from modeling import (
    DocFormerEmbeddings,
    DocFormerEncoder,
    LanguageFeatureExtractor,
    ResNetFeatureExtractor,
)

# Default number of target classes (previously hard-coded as 16 in two places).
NUM_CLASSES = 16


class DocFormerForClassification(nn.Module):
    """DocFormer backbone followed by a linear classification head.

    Args:
        config: dict-like model configuration. This class reads
            ``'max_position_embeddings'``, ``'hidden_dropout_prob'`` and
            ``'hidden_size'``; the whole dict is also passed to
            ``DocFormerEmbeddings`` and ``DocFormerEncoder``.
        num_classes: number of output logits per example (default 16).
    """

    def __init__(self, config, num_classes=NUM_CLASSES):
        super().__init__()
        # NOTE(review): hidden_dim comes from max_position_embeddings rather than
        # hidden_size — presumably the visual sequence length; confirm in modeling.
        self.resnet = ResNetFeatureExtractor(hidden_dim=config['max_position_embeddings'])
        self.embeddings = DocFormerEmbeddings(config)
        self.lang_emb = LanguageFeatureExtractor()
        self.config = config
        self.dropout = nn.Dropout(config['hidden_dropout_prob'])
        self.linear_layer = nn.Linear(in_features=config['hidden_size'], out_features=num_classes)
        self.encoder = DocFormerEncoder(config)

    def forward(self, batch_dict):
        """Return classification logits, one row per example.

        Args:
            batch_dict: mapping with keys ``'x_features'``, ``'y_features'``,
                ``'input_ids'`` and ``'resized_scaled_img'``.

        Returns:
            Tensor of shape ``(B, num_classes)``, where ``B`` is the leading
            dimension of the encoder output (assumed to be the batch axis).
        """
        x_feat = batch_dict['x_features']
        y_feat = batch_dict['y_features']
        token = batch_dict['input_ids']
        img = batch_dict['resized_scaled_img']

        v_bar_s, t_bar_s = self.embeddings(x_feat, y_feat)
        v_bar = self.resnet(img)
        t_bar = self.lang_emb(token)
        out = self.encoder(t_bar, v_bar, t_bar_s, v_bar_s)

        # Keep only the first sequence position *before* the head: the linear
        # layer acts per position, so this yields the same logits as projecting
        # every position and slicing afterwards, without the wasted work.
        pooled = out[:, 0, :]
        # The dropout layer was constructed but never used; apply it to the
        # pooled representation (no-op in eval mode).
        pooled = self.dropout(pooled)
        return self.linear_layer(pooled)


## Defining pytorch lightning model
class DocFormer(pl.LightningModule):
    """Lightning module that trains and validates ``DocFormerForClassification``.

    Args:
        config: model configuration dict forwarded to the classifier.
        lr: AdamW learning rate (stored via ``save_hyperparameters``).
        num_classes: number of target classes (default 16).
    """

    def __init__(self, config, lr=5e-5, num_classes=NUM_CLASSES):
        super().__init__()
        self.save_hyperparameters()
        self.config = config
        self.num_classes = num_classes
        self.docformer = DocFormerForClassification(config, num_classes=num_classes)

        # NOTE(review): these constructors (no `task=` argument) follow the
        # pre-0.11 torchmetrics API — confirm the pinned torchmetrics version.
        self.train_accuracy_metric = torchmetrics.Accuracy()
        self.val_accuracy_metric = torchmetrics.Accuracy()
        self.f1_metric = torchmetrics.F1Score(num_classes=self.num_classes)
        self.precision_macro_metric = torchmetrics.Precision(
            average="macro", num_classes=self.num_classes
        )
        self.recall_macro_metric = torchmetrics.Recall(
            average="macro", num_classes=self.num_classes
        )
        self.precision_micro_metric = torchmetrics.Precision(average="micro")
        self.recall_micro_metric = torchmetrics.Recall(average="micro")

    def forward(self, batch_dict):
        """Return classification logits for ``batch_dict``."""
        logits = self.docformer(batch_dict)
        return logits

    def _shared_step(self, batch):
        """Run the model on ``batch``; return ``(logits, loss, preds, labels)``."""
        logits = self(batch)
        labels = batch['label']
        loss = nn.functional.cross_entropy(logits, labels)
        preds = torch.argmax(logits, 1)
        return logits, loss, preds, labels

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss and log loss/accuracy for one batch."""
        _, loss, preds, labels = self._shared_step(batch)

        # Calling the metric updates its running state and caches the batch
        # value. Logging the Metric object (not the returned tensor) lets
        # Lightning report the true epoch value and reset the state each epoch.
        self.train_accuracy_metric(preds, labels)

        self.log('train/loss', loss, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        self.log(
            'train/acc', self.train_accuracy_metric,
            prog_bar=True, on_epoch=True, logger=True, on_step=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and metrics; return labels and logits for epoch-end plots."""
        logits, loss, preds, labels = self._shared_step(batch)

        self.log("valid/loss", loss, prog_bar=True, on_step=True, logger=True)

        # Macro precision/recall averaged over batches is not the epoch macro
        # score, so log the Metric objects and let Lightning call compute().
        step_and_epoch_metrics = (
            ("valid/acc", self.val_accuracy_metric),
            ("valid/precision_macro", self.precision_macro_metric),
            ("valid/recall_macro", self.recall_macro_metric),
            ("valid/precision_micro", self.precision_micro_metric),
            ("valid/recall_micro", self.recall_micro_metric),
        )
        for name, metric in step_and_epoch_metrics:
            metric(preds, labels)
            self.log(name, metric, prog_bar=True, on_epoch=True, logger=True, on_step=True)

        self.f1_metric(preds, labels)
        self.log("valid/f1", self.f1_metric, prog_bar=True, on_epoch=True)

        return {"label": labels, "logits": logits}

    def validation_epoch_end(self, outputs):
        """Log a confusion matrix and ROC curve over the whole validation epoch.

        NOTE(review): ``validation_epoch_end`` is a pytorch_lightning < 2.0 hook.
        """
        if not outputs:
            return
        labels_t = torch.cat([x["label"] for x in outputs])
        logits = torch.cat([x["logits"] for x in outputs])

        labels = labels_t.cpu().numpy()
        preds = torch.argmax(logits, 1).cpu().numpy()
        # wandb.plot.roc_curve expects per-class probabilities, not raw logits;
        # .float() avoids handing half-precision arrays to wandb under AMP.
        probas = torch.softmax(logits.float(), dim=1).cpu().numpy()

        wandb.log({"cm": wandb.sklearn.plot_confusion_matrix(labels, preds)})
        self.logger.experiment.log({"roc": wandb.plot.roc_curve(labels, probas)})

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters using the saved ``lr``."""
        return torch.optim.AdamW(self.parameters(), lr=self.hparams['lr'])