Gagan Bhatia commited on
Commit
3315012
1 Parent(s): 9e7288d

Update model.py

Browse files
Files changed (1) hide show
  1. src/models/model.py +9 -9
src/models/model.py CHANGED
@@ -89,15 +89,15 @@ class DataModule(Dataset):
89
 
90
  class PLDataModule(LightningDataModule):
91
  def __init__(
92
- self,
93
- train_df: pd.DataFrame,
94
- test_df: pd.DataFrame,
95
- tokenizer: T5Tokenizer,
96
- source_max_token_len: int = 512,
97
- target_max_token_len: int = 512,
98
- batch_size: int = 4,
99
- split: float = 0.1,
100
- num_workers: int = 2
101
  ):
102
  """
103
  :param data_df:
 
89
 
90
  class PLDataModule(LightningDataModule):
91
  def __init__(
92
+ self,
93
+ train_df: pd.DataFrame,
94
+ test_df: pd.DataFrame,
95
+ tokenizer: T5Tokenizer,
96
+ source_max_token_len: int = 512,
97
+ target_max_token_len: int = 512,
98
+ batch_size: int = 4,
99
+ split: float = 0.1,
100
+ num_workers: int = 2,
101
  ):
102
  """
103
  :param data_df: