gagan3012 commited on
Commit
698a370
1 Parent(s): 90610ad

split added

Browse files
Files changed (2) hide show
  1. params.yml +3 -2
  2. src/models/train_model.py +3 -0
params.yml CHANGED
@@ -2,9 +2,10 @@ data: cnn_dailymail
2
  batch_size: 4
3
  num_workers: 2
4
  model_type: t5
5
- model_name: t5-base
6
  learning_rate: 1e-4
7
  epochs: 5
8
  source_dir: src
9
  model_dir: models
10
- metric: rouge
 
 
2
  batch_size: 4
3
  num_workers: 2
4
  model_type: t5
5
+ model_name: t5-small
6
  learning_rate: 1e-4
7
  epochs: 5
8
  source_dir: src
9
  model_dir: models
10
+ metric: rouge
11
+ split: 0.001
src/models/train_model.py CHANGED
@@ -15,6 +15,9 @@ def train_model():
15
  train_df = pd.read_csv('data/processed/train.csv')
16
  eval_df = pd.read_csv('data/processed/validation.csv')
17
 
 
 
 
18
  model = Summarization()
19
  model.from_pretrained(model_type=params['model_type'], model_name=params['model_name'])
20
 
 
15
  train_df = pd.read_csv('data/processed/train.csv')
16
  eval_df = pd.read_csv('data/processed/validation.csv')
17
 
18
+ train_df = train_df.sample(frac=params['split'], replace=True, random_state=1)
19
+ eval_df = eval_df.sample(frac=params['split'], replace=True, random_state=1)
20
+
21
  model = Summarization()
22
  model.from_pretrained(model_type=params['model_type'], model_name=params['model_name'])
23