File size: 1,227 Bytes
adeafa0
f8a2915
cdce7a5
 
dfda3c6
cb1227c
 
297c713
 
 
 
 
cdce7a5
 
 
d51b694
e057e5a
5216253
d51b694
fb49f49
0fe8bf6
5cb4c53
d51b694
5222a47
7056372
e3ba5c5
3c8cb17
bd81828
 
 
 
7ce22c0
 
 
7547250
 
f2e152e
293f3ab
7f5fe96
51b4b2d
45e8538
 
7b5efb9
05f29e7
 
7f5fe96
b2f69cf
7f5fe96
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import json

import yaml

from model import Summarization
import pandas as pd


def train_model(params_path="params.yml"):
    """
    Train the summarization model end-to-end.

    Reads hyperparameters from *params_path*, loads the processed train /
    validation CSVs, fits a ``Summarization`` model, saves it, and mirrors
    the latest wandb run summary into ``reports/training_metrics.txt``.

    Parameters
    ----------
    params_path : str
        Path to the YAML file holding hyperparameters. Defaults to
        ``"params.yml"`` so existing callers keep working unchanged.

    Raises
    ------
    FileNotFoundError
        If the config, data, or wandb summary file is missing.
    KeyError
        If a required hyperparameter key is absent from the YAML file.
    """
    with open(params_path, encoding="utf-8") as f:
        params = yaml.safe_load(f)

    # Load the processed data splits.
    train_df = pd.read_csv("data/processed/train.csv")
    eval_df = pd.read_csv("data/processed/validation.csv")

    # NOTE(review): replace=True samples WITH replacement, so some rows get
    # duplicated and others dropped even for frac < 1. If the intent is a
    # plain subsample, replace=False is almost certainly what is wanted —
    # confirm before changing, since it alters the training distribution.
    train_df = train_df.sample(frac=params["split"], replace=True, random_state=1)
    eval_df = eval_df.sample(frac=params["split"], replace=True, random_state=1)

    model = Summarization()
    model.from_pretrained(
        model_type=params["model_type"], model_name=params["model_name"]
    )

    model.train(
        train_df=train_df,
        eval_df=eval_df,
        batch_size=params["batch_size"],
        max_epochs=params["epochs"],
        use_gpu=params["use_gpu"],
        # YAML may parse these as strings (e.g. "1e-4"); coerce explicitly.
        learning_rate=float(params["learning_rate"]),
        num_workers=int(params["num_workers"]),
    )

    model.save_model(model_dir=params["model_dir"])

    # Copy the latest wandb run summary into the reports directory so the
    # training metrics are versioned alongside the code.
    with open(
        "wandb/latest-run/files/wandb-summary.json", encoding="utf-8"
    ) as json_file:
        data = json.load(json_file)

    with open("reports/training_metrics.txt", "w", encoding="utf-8") as fp:
        json.dump(data, fp)


# Run training only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    train_model()