Spaces:
Runtime error
Runtime error
Gagan Bhatia
commited on
Commit
•
0015a3c
1
Parent(s):
43e0847
Update model.py
Browse files- src/models/model.py +6 -6
src/models/model.py
CHANGED
@@ -160,7 +160,7 @@ class PLDataModule(LightningDataModule):
|
|
160 |
|
161 |
|
162 |
class LightningModel(LightningModule):
|
163 |
-
"""
|
164 |
|
165 |
def __init__(
|
166 |
self,
|
@@ -187,7 +187,7 @@ class LightningModel(LightningModule):
|
|
187 |
self.weight_decay = weight_decay
|
188 |
|
189 |
def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
|
190 |
-
"""
|
191 |
output = self.model(
|
192 |
input_ids,
|
193 |
attention_mask=attention_mask,
|
@@ -198,7 +198,7 @@ class LightningModel(LightningModule):
|
|
198 |
return output.loss, output.logits
|
199 |
|
200 |
def training_step(self, batch, batch_size):
|
201 |
-
"""
|
202 |
input_ids = batch["keywords_input_ids"]
|
203 |
attention_mask = batch["keywords_attention_mask"]
|
204 |
labels = batch["labels"]
|
@@ -214,7 +214,7 @@ class LightningModel(LightningModule):
|
|
214 |
return loss
|
215 |
|
216 |
def validation_step(self, batch, batch_size):
|
217 |
-
"""
|
218 |
input_ids = batch["keywords_input_ids"]
|
219 |
attention_mask = batch["keywords_attention_mask"]
|
220 |
labels = batch["labels"]
|
@@ -230,7 +230,7 @@ class LightningModel(LightningModule):
|
|
230 |
return loss
|
231 |
|
232 |
def test_step(self, batch, batch_size):
|
233 |
-
"""
|
234 |
input_ids = batch["keywords_input_ids"]
|
235 |
attention_mask = batch["keywords_attention_mask"]
|
236 |
labels = batch["labels"]
|
@@ -247,7 +247,7 @@ class LightningModel(LightningModule):
|
|
247 |
return loss
|
248 |
|
249 |
def configure_optimizers(self):
|
250 |
-
"""
|
251 |
model = self.model
|
252 |
no_decay = ["bias", "LayerNorm.weight"]
|
253 |
optimizer_grouped_parameters = [
|
|
|
160 |
|
161 |
|
162 |
class LightningModel(LightningModule):
|
163 |
+
"""PyTorch Lightning Model class"""
|
164 |
|
165 |
def __init__(
|
166 |
self,
|
|
|
187 |
self.weight_decay = weight_decay
|
188 |
|
189 |
def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
|
190 |
+
"""forward step"""
|
191 |
output = self.model(
|
192 |
input_ids,
|
193 |
attention_mask=attention_mask,
|
|
|
198 |
return output.loss, output.logits
|
199 |
|
200 |
def training_step(self, batch, batch_size):
|
201 |
+
"""training step"""
|
202 |
input_ids = batch["keywords_input_ids"]
|
203 |
attention_mask = batch["keywords_attention_mask"]
|
204 |
labels = batch["labels"]
|
|
|
214 |
return loss
|
215 |
|
216 |
def validation_step(self, batch, batch_size):
|
217 |
+
"""validation step"""
|
218 |
input_ids = batch["keywords_input_ids"]
|
219 |
attention_mask = batch["keywords_attention_mask"]
|
220 |
labels = batch["labels"]
|
|
|
230 |
return loss
|
231 |
|
232 |
def test_step(self, batch, batch_size):
|
233 |
+
"""test step"""
|
234 |
input_ids = batch["keywords_input_ids"]
|
235 |
attention_mask = batch["keywords_attention_mask"]
|
236 |
labels = batch["labels"]
|
|
|
247 |
return loss
|
248 |
|
249 |
def configure_optimizers(self):
|
250 |
+
"""configure optimizers"""
|
251 |
model = self.model
|
252 |
no_decay = ["bias", "LayerNorm.weight"]
|
253 |
optimizer_grouped_parameters = [
|