Gagan Bhatia commited on
Commit
9ed09a4
1 Parent(s): ae4f8b6

Update model.py

Browse files
Files changed (1) hide show
  1. src/models/model.py +9 -0
src/models/model.py CHANGED
@@ -141,3 +141,12 @@ class PLDataModule(LightningDataModule):
141
  def val_dataloader(self):
142
  """ validation dataloader """
143
  return DataLoader(
 
 
 
 
 
 
 
 
 
 
141
  def val_dataloader(self):
142
  """ validation dataloader """
143
  return DataLoader(
144
+ self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2
145
+ )
146
+
147
+
148
+ class LightningModel(LightningModule):
149
+ """ PyTorch Lightning Model class"""
150
+
151
+ def __init__(self, tokenizer, model, output: str = "outputs"):
152
+ """