Spaces:
Runtime error
Runtime error
Gagan Bhatia
commited on
Commit
·
056c147
1
Parent(s):
6e46269
Update model.py
Browse files- src/models/model.py +17 -0
src/models/model.py
CHANGED
|
@@ -168,3 +168,20 @@ class LightningModel(LightningModule):
|
|
| 168 |
output = self.model(
|
| 169 |
input_ids,
|
| 170 |
attention_mask=attention_mask,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
output = self.model(
|
| 169 |
input_ids,
|
| 170 |
attention_mask=attention_mask,
|
| 171 |
+
labels=labels,
|
| 172 |
+
decoder_attention_mask=decoder_attention_mask,
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
return output.loss, output.logits
|
| 176 |
+
|
| 177 |
+
def training_step(self, batch, batch_size):
|
| 178 |
+
""" training step """
|
| 179 |
+
input_ids = batch["keywords_input_ids"]
|
| 180 |
+
attention_mask = batch["keywords_attention_mask"]
|
| 181 |
+
labels = batch["labels"]
|
| 182 |
+
labels_attention_mask = batch["labels_attention_mask"]
|
| 183 |
+
|
| 184 |
+
loss, outputs = self(
|
| 185 |
+
input_ids=input_ids,
|
| 186 |
+
attention_mask=attention_mask,
|
| 187 |
+
decoder_attention_mask=labels_attention_mask,
|