Gagan Bhatia committed
Commit 0015a3c
Parent: 43e0847

Update model.py

Files changed (1):
  src/models/model.py  +6 -6
src/models/model.py CHANGED
@@ -160,7 +160,7 @@ class PLDataModule(LightningDataModule):
 
 
 class LightningModel(LightningModule):
-    """ PyTorch Lightning Model class"""
+    """PyTorch Lightning Model class"""
 
     def __init__(
         self,
@@ -187,7 +187,7 @@ class LightningModel(LightningModule):
         self.weight_decay = weight_decay
 
     def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
-        """ forward step """
+        """forward step"""
         output = self.model(
             input_ids,
             attention_mask=attention_mask,
@@ -198,7 +198,7 @@ class LightningModel(LightningModule):
         return output.loss, output.logits
 
     def training_step(self, batch, batch_size):
-        """ training step """
+        """training step"""
         input_ids = batch["keywords_input_ids"]
         attention_mask = batch["keywords_attention_mask"]
         labels = batch["labels"]
@@ -214,7 +214,7 @@ class LightningModel(LightningModule):
         return loss
 
     def validation_step(self, batch, batch_size):
-        """ validation step """
+        """validation step"""
         input_ids = batch["keywords_input_ids"]
         attention_mask = batch["keywords_attention_mask"]
         labels = batch["labels"]
@@ -230,7 +230,7 @@ class LightningModel(LightningModule):
         return loss
 
     def test_step(self, batch, batch_size):
-        """ test step """
+        """test step"""
         input_ids = batch["keywords_input_ids"]
         attention_mask = batch["keywords_attention_mask"]
         labels = batch["labels"]
@@ -247,7 +247,7 @@ class LightningModel(LightningModule):
         return loss
 
     def configure_optimizers(self):
-        """ configure optimizers """
+        """configure optimizers"""
         model = self.model
         no_decay = ["bias", "LayerNorm.weight"]
         optimizer_grouped_parameters = [
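For context on the hunks above: the diff elides the middle of training_step (original lines 205-213) between the batch unpacking and the returned loss. A minimal sketch of the usual shape of that middle follows, assuming forward() is invoked via self(...) and that the decoder mask sits under a "labels_attention_mask" batch key; both names are hypothetical, inferred only from the forward() signature shown above, not taken from the file.

# Hypothetical sketch only: the body between the batch unpacking and
# `return loss` is not shown in the diff. The "labels_attention_mask"
# key and the logged metric name are assumptions.
def training_step(self, batch, batch_size):
    """training step"""
    input_ids = batch["keywords_input_ids"]
    attention_mask = batch["keywords_attention_mask"]
    labels = batch["labels"]
    # forward() returns (loss, logits), per the hunk at original line 198
    loss, logits = self(
        input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=batch["labels_attention_mask"],
        labels=labels,
    )
    self.log("train_loss", loss, prog_bar=True, logger=True)
    return loss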
 
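The final hunk stops at the opening bracket of optimizer_grouped_parameters. Given the no_decay list of "bias" and "LayerNorm.weight" and the stored self.weight_decay, the method very likely builds the standard two-group weight-decay split; the sketch below shows that pattern under the assumption that the optimizer is AdamW and the model exposes a self.learning_rate attribute, neither of which is confirmed by the diff.

# Sketch of the standard no-decay parameter grouping; the continuation
# past `optimizer_grouped_parameters = [`, the AdamW choice, and
# self.learning_rate are assumptions, not the file's actual code.
from torch.optim import AdamW

def configure_optimizers(self):
    """configure optimizers"""
    model = self.model
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            # decayed group: every parameter except biases and LayerNorm weights
            "params": [
                p
                for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": self.weight_decay,
        },
        {
            # exempt group: biases and LayerNorm weights get no weight decay
            "params": [
                p
                for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    return AdamW(optimizer_grouped_parameters, lr=self.learning_rate)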