Paarth commited on
Commit
662f2cf
1 Parent(s): 011c557

Upload summarizer.py

Browse files
Files changed (1) hide show
  1. summarizer.py +74 -0
summarizer.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import pytorch_lightning as pl
4
+ from torch import nn
5
+ from transformers import AdamW
6
+ from transformers import T5ForConditionalGeneration
7
+ from pytorch_lightning.callbacks import ModelCheckpoint
8
+ from pytorch_lightning.loggers import TensorBoardLogger
9
+
10
+ class SummarizerModel(pl.LightningModule):
11
+ def __init__(self, model_name = None):
12
+ super().__init__()
13
+ self.model = T5ForConditionalGeneration.from_pretrained(model_name, return_dict = True)
14
+
15
+ def forward(self,
16
+ input_ids,
17
+ attention_mask,
18
+ decoder_attention_mask,
19
+ labels = None):
20
+ output = self.model(
21
+ input_ids,
22
+ attention_mask = attention_mask,
23
+ labels = labels,
24
+ decoder_attention_mask = decoder_attention_mask
25
+ )
26
+ return output.loss, output.logits
27
+
28
+ def training_step(self, batch, batch_idx):
29
+ input_ids = batch['text_input_ids']
30
+ attention_mask = batch['text_attention_mask']
31
+ labels = batch['labels']
32
+ decoder_attention_mask = batch['labels_attention_mask']
33
+
34
+ loss, outputs = self.forward(
35
+ input_ids = input_ids,
36
+ attention_mask = attention_mask,
37
+ decoder_attention_mask = decoder_attention_mask,
38
+ labels = labels
39
+ )
40
+ self.log("train_loss", loss, prog_bar = True, logger = True)
41
+ return loss
42
+
43
+ def validation_step(self, batch, batch_idx):
44
+ input_ids = batch['text_input_ids']
45
+ attention_mask = batch['text_attention_mask']
46
+ labels = batch['labels']
47
+ decoder_attention_mask = batch['labels_attention_mask']
48
+
49
+ loss, outputs = self.forward(
50
+ input_ids = input_ids,
51
+ attention_mask = attention_mask,
52
+ decoder_attention_mask = decoder_attention_mask,
53
+ labels = labels
54
+ )
55
+ self.log("val_loss", loss, prog_bar = True, logger = True)
56
+ return loss
57
+
58
+ def test_step(self, batch, batch_idx):
59
+ input_ids = batch['text_input_ids']
60
+ attention_mask = batch['text_attention_mask']
61
+ labels = batch['labels']
62
+ decoder_attention_mask = batch['labels_attention_mask']
63
+
64
+ loss, outputs = self.forward(
65
+ input_ids = input_ids,
66
+ attention_mask = attention_mask,
67
+ decoder_attention_mask = decoder_attention_mask,
68
+ labels = labels
69
+ )
70
+ self.log("test_loss", loss, prog_bar = True, logger = True)
71
+ return loss
72
+
73
+ def configure_optimizers(self):
74
+ return AdamW(self.model.parameters(), lr = 0.0001)