Gagan Bhatia commited on
Commit
6e46269
1 Parent(s): 9ed09a4

Update model.py

Browse files
Files changed (1) hide show
  1. src/models/model.py +18 -0
src/models/model.py CHANGED
@@ -150,3 +150,21 @@ class LightningModel(LightningModule):
150
 
151
  def __init__(self, tokenizer, model, output: str = "outputs"):
152
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
  def __init__(self, tokenizer, model, output: str = "outputs"):
152
  """
153
+ initiates a PyTorch Lightning Model
154
+ Args:
155
+ tokenizer : T5 tokenizer
156
+ model : T5 model
157
+ output (str, optional): output directory to save model checkpoints. Defaults to "outputs".
158
+ """
159
+ super().__init__()
160
+ self.model = model
161
+ self.tokenizer = tokenizer
162
+ self.output = output
163
+ # self.val_acc = Accuracy()
164
+ # self.train_acc = Accuracy()
165
+
166
+ def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
167
+ """ forward step """
168
+ output = self.model(
169
+ input_ids,
170
+ attention_mask=attention_mask,