iakarshu commited on
Commit
350be94
1 Parent(s): ed39c5e

Delete utils.py

Browse files
Files changed (1) hide show
  1. utils.py +0 -122
utils.py DELETED
@@ -1,122 +0,0 @@
1
- import torch.nn as nn
2
- from modeling import DocFormerEncoder,ResNetFeatureExtractor,DocFormerEmbeddings,LanguageFeatureExtractor
3
-
4
- class DocFormerForClassification(nn.Module):
5
-
6
- def __init__(self, config):
7
- super(DocFormerForClassification, self).__init__()
8
-
9
- self.resnet = ResNetFeatureExtractor(hidden_dim = config['max_position_embeddings'])
10
- self.embeddings = DocFormerEmbeddings(config)
11
- self.lang_emb = LanguageFeatureExtractor()
12
- self.config = config
13
- self.dropout = nn.Dropout(config['hidden_dropout_prob'])
14
- self.linear_layer = nn.Linear(in_features = config['hidden_size'], out_features = 16) ## Number of Classes
15
- self.encoder = DocFormerEncoder(config)
16
-
17
- def forward(self, batch_dict):
18
-
19
- x_feat = batch_dict['x_features']
20
- y_feat = batch_dict['y_features']
21
-
22
- token = batch_dict['input_ids']
23
- img = batch_dict['resized_scaled_img']
24
-
25
- v_bar_s, t_bar_s = self.embeddings(x_feat,y_feat)
26
- v_bar = self.resnet(img)
27
- t_bar = self.lang_emb(token)
28
- out = self.encoder(t_bar,v_bar,t_bar_s,v_bar_s)
29
- out = self.linear_layer(out)
30
- out = out[:, 0, :]
31
- return out
32
-
33
-
34
- ## Defining pytorch lightning model
35
- import pytorch_lightning as pl
36
- from sklearn.metrics import accuracy_score, confusion_matrix
37
- import pandas as pd
38
- import matplotlib.pyplot as plt
39
- import seaborn as sns
40
- import numpy as np
41
- import torchmetrics
42
- import wandb
43
- import torch
44
-
45
- class DocFormer(pl.LightningModule):
46
-
47
- def __init__(self, config , lr = 5e-5):
48
- super(DocFormer, self).__init__()
49
-
50
- self.save_hyperparameters()
51
- self.config = config
52
- self.docformer = DocFormerForClassification(config)
53
-
54
- self.num_classes = 16
55
- self.train_accuracy_metric = torchmetrics.Accuracy()
56
- self.val_accuracy_metric = torchmetrics.Accuracy()
57
- self.f1_metric = torchmetrics.F1Score(num_classes=self.num_classes)
58
- self.precision_macro_metric = torchmetrics.Precision(
59
- average="macro", num_classes=self.num_classes
60
- )
61
- self.recall_macro_metric = torchmetrics.Recall(
62
- average="macro", num_classes=self.num_classes
63
- )
64
- self.precision_micro_metric = torchmetrics.Precision(average="micro")
65
- self.recall_micro_metric = torchmetrics.Recall(average="micro")
66
-
67
- def forward(self, batch_dict):
68
- logits = self.docformer(batch_dict)
69
- return logits
70
-
71
- def training_step(self, batch, batch_idx):
72
- logits = self.forward(batch)
73
-
74
- loss = nn.CrossEntropyLoss()(logits, batch['label'])
75
- preds = torch.argmax(logits, 1)
76
-
77
- ## Calculating the accuracy score
78
- train_acc = self.train_accuracy_metric(preds, batch["label"])
79
-
80
- ## Logging
81
- self.log('train/loss', loss,prog_bar = True, on_epoch=True, logger=True, on_step=True)
82
- self.log('train/acc', train_acc, prog_bar = True, on_epoch=True, logger=True, on_step=True)
83
-
84
- return loss
85
-
86
- def validation_step(self, batch, batch_idx):
87
- logits = self.forward(batch)
88
- loss = nn.CrossEntropyLoss()(logits, batch['label'])
89
- preds = torch.argmax(logits, 1)
90
-
91
- labels = batch['label']
92
- # Metrics
93
- valid_acc = self.val_accuracy_metric(preds, labels)
94
- precision_macro = self.precision_macro_metric(preds, labels)
95
- recall_macro = self.recall_macro_metric(preds, labels)
96
- precision_micro = self.precision_micro_metric(preds, labels)
97
- recall_micro = self.recall_micro_metric(preds, labels)
98
- f1 = self.f1_metric(preds, labels)
99
-
100
- # Logging metrics
101
- self.log("valid/loss", loss, prog_bar=True, on_step=True, logger=True)
102
- self.log("valid/acc", valid_acc, prog_bar=True, on_epoch=True, logger=True, on_step=True)
103
- self.log("valid/precision_macro", precision_macro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
104
- self.log("valid/recall_macro", recall_macro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
105
- self.log("valid/precision_micro", precision_micro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
106
- self.log("valid/recall_micro", recall_micro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
107
- self.log("valid/f1", f1, prog_bar=True, on_epoch=True)
108
-
109
- return {"label": batch['label'], "logits": logits}
110
-
111
- def validation_epoch_end(self, outputs):
112
- labels = torch.cat([x["label"] for x in outputs])
113
- logits = torch.cat([x["logits"] for x in outputs])
114
- preds = torch.argmax(logits, 1)
115
-
116
- wandb.log({"cm": wandb.sklearn.plot_confusion_matrix(labels.cpu().numpy(), preds.cpu().numpy())})
117
- self.logger.experiment.log(
118
- {"roc": wandb.plot.roc_curve(labels.cpu().numpy(), logits.cpu().numpy())}
119
- )
120
-
121
- def configure_optimizers(self):
122
- return torch.optim.AdamW(self.parameters(), lr = self.hparams['lr'])