Gagan Bhatia commited on
Commit
056c147
·
1 Parent(s): 6e46269

Update model.py

Browse files
Files changed (1) hide show
  1. src/models/model.py +17 -0
src/models/model.py CHANGED
@@ -168,3 +168,20 @@ class LightningModel(LightningModule):
168
  output = self.model(
169
  input_ids,
170
  attention_mask=attention_mask,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  output = self.model(
169
  input_ids,
170
  attention_mask=attention_mask,
171
+ labels=labels,
172
+ decoder_attention_mask=decoder_attention_mask,
173
+ )
174
+
175
+ return output.loss, output.logits
176
+
177
+ def training_step(self, batch, batch_size):
178
+ """ training step """
179
+ input_ids = batch["keywords_input_ids"]
180
+ attention_mask = batch["keywords_attention_mask"]
181
+ labels = batch["labels"]
182
+ labels_attention_mask = batch["labels_attention_mask"]
183
+
184
+ loss, outputs = self(
185
+ input_ids=input_ids,
186
+ attention_mask=attention_mask,
187
+ decoder_attention_mask=labels_attention_mask,