Spaces:
Runtime error
Runtime error
Gagan Bhatia
commited on
Commit
·
0bc3261
1
Parent(s):
0015a3c
Update model.py
Browse files- src/models/model.py +14 -14
src/models/model.py
CHANGED
@@ -296,20 +296,20 @@ class Summarization:
|
|
296 |
)
|
297 |
|
298 |
def train(
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
):
|
314 |
"""
|
315 |
trains T5/MT5 model on custom dataset
|
|
|
296 |
)
|
297 |
|
298 |
def train(
|
299 |
+
self,
|
300 |
+
train_df: pd.DataFrame,
|
301 |
+
eval_df: pd.DataFrame,
|
302 |
+
source_max_token_len: int = 512,
|
303 |
+
target_max_token_len: int = 512,
|
304 |
+
batch_size: int = 8,
|
305 |
+
max_epochs: int = 5,
|
306 |
+
use_gpu: bool = True,
|
307 |
+
outputdir: str = "models",
|
308 |
+
early_stopping_patience_epochs: int = 0, # 0 to disable early stopping feature
|
309 |
+
learning_rate: float = 0.0001,
|
310 |
+
adam_epsilon: float = 0.01,
|
311 |
+
num_workers: int = 2,
|
312 |
+
weight_decay: float = 0.0001,
|
313 |
):
|
314 |
"""
|
315 |
trains T5/MT5 model on custom dataset
|