Spaces:
Runtime error
Runtime error
Gagan Bhatia
commited on
Commit
·
056c147
1
Parent(s):
6e46269
Update model.py
Browse files- src/models/model.py +17 -0
src/models/model.py
CHANGED
@@ -168,3 +168,20 @@ class LightningModel(LightningModule):
|
|
168 |
output = self.model(
|
169 |
input_ids,
|
170 |
attention_mask=attention_mask,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
output = self.model(
|
169 |
input_ids,
|
170 |
attention_mask=attention_mask,
|
171 |
+
labels=labels,
|
172 |
+
decoder_attention_mask=decoder_attention_mask,
|
173 |
+
)
|
174 |
+
|
175 |
+
return output.loss, output.logits
|
176 |
+
|
177 |
+
def training_step(self, batch, batch_size):
|
178 |
+
""" training step """
|
179 |
+
input_ids = batch["keywords_input_ids"]
|
180 |
+
attention_mask = batch["keywords_attention_mask"]
|
181 |
+
labels = batch["labels"]
|
182 |
+
labels_attention_mask = batch["labels_attention_mask"]
|
183 |
+
|
184 |
+
loss, outputs = self(
|
185 |
+
input_ids=input_ids,
|
186 |
+
attention_mask=attention_mask,
|
187 |
+
decoder_attention_mask=labels_attention_mask,
|