Gagan Bhatia commited on
Commit
d51b694
1 Parent(s): 297c713

Update train_model.py

Browse files
Files changed (1) hide show
  1. src/models/train_model.py +8 -0
src/models/train_model.py CHANGED
@@ -5,3 +5,11 @@ def train_model():
5
  """
6
  Train the model
7
  """
 
 
 
 
 
 
 
 
 
5
  """
6
  Train the model
7
  """
8
+ # Load the data
9
+ train_df = make_dataset(split = 'train')
10
+ eval_df = make_dataset(split = 'test')
11
+
12
+ model = Summarization()
13
+ model.from_pretrained('t5-base')
14
+ model.train(train_df=train_df, eval_df=eval_df, batch_size=4, max_epochs=3, use_gpu=True)
15
+ model.save_model()