Gagan Bhatia commited on
Commit
0bc3261
·
1 Parent(s): 0015a3c

Update model.py

Browse files
Files changed (1) hide show
  1. src/models/model.py +14 -14
src/models/model.py CHANGED
@@ -296,20 +296,20 @@ class Summarization:
296
  )
297
 
298
  def train(
299
- self,
300
- train_df: pd.DataFrame,
301
- eval_df: pd.DataFrame,
302
- source_max_token_len: int = 512,
303
- target_max_token_len: int = 512,
304
- batch_size: int = 8,
305
- max_epochs: int = 5,
306
- use_gpu: bool = True,
307
- outputdir: str = "models",
308
- early_stopping_patience_epochs: int = 0, # 0 to disable early stopping feature
309
- learning_rate: float = 0.0001,
310
- adam_epsilon: float = 0.01,
311
- num_workers: int = 2,
312
- weight_decay: float = 0.0001
313
  ):
314
  """
315
  trains T5/MT5 model on custom dataset
 
296
  )
297
 
298
  def train(
299
+ self,
300
+ train_df: pd.DataFrame,
301
+ eval_df: pd.DataFrame,
302
+ source_max_token_len: int = 512,
303
+ target_max_token_len: int = 512,
304
+ batch_size: int = 8,
305
+ max_epochs: int = 5,
306
+ use_gpu: bool = True,
307
+ outputdir: str = "models",
308
+ early_stopping_patience_epochs: int = 0, # 0 to disable early stopping feature
309
+ learning_rate: float = 0.0001,
310
+ adam_epsilon: float = 0.01,
311
+ num_workers: int = 2,
312
+ weight_decay: float = 0.0001,
313
  ):
314
  """
315
  trains T5/MT5 model on custom dataset