Spaces:
Runtime error
Runtime error
File size: 1,227 Bytes
adeafa0 f8a2915 cdce7a5 dfda3c6 cb1227c 297c713 cdce7a5 d51b694 e057e5a 5216253 d51b694 fb49f49 0fe8bf6 5cb4c53 d51b694 5222a47 7056372 e3ba5c5 3c8cb17 bd81828 7ce22c0 7547250 f2e152e 293f3ab 7f5fe96 51b4b2d 45e8538 7b5efb9 05f29e7 7f5fe96 b2f69cf 7f5fe96 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import json
import yaml
from model import Summarization
import pandas as pd
def train_model():
    """
    Train the summarization model using the settings in ``params.yml``.

    Steps:
      1. Load ``params.yml`` with ``yaml.safe_load``.
      2. Read ``data/processed/train.csv`` and ``data/processed/validation.csv``
         and subsample each with ``frac=params["split"]`` (fixed seed 1).
      3. Load a pretrained ``Summarization`` model (``model_type`` /
         ``model_name``), train it, and save it to ``params["model_dir"]``.
      4. Copy the latest wandb run summary
         (``wandb/latest-run/files/wandb-summary.json``) to
         ``reports/training_metrics.txt`` as JSON.

    Raises:
        FileNotFoundError: if ``params.yml``, either CSV, or the wandb summary
            file is missing.
        KeyError: if a required key is absent from ``params.yml``.
    """
    import os  # local import keeps the file's top-level imports unchanged

    with open("params.yml", encoding="utf-8") as f:
        params = yaml.safe_load(f)

    # Load the data
    train_df = pd.read_csv("data/processed/train.csv")
    eval_df = pd.read_csv("data/processed/validation.csv")
    # NOTE(review): replace=True draws a bootstrap sample, so rows may repeat.
    # Kept as-is because it may be intentional (it also allows split > 1);
    # confirm that a plain subset (replace=False) was not what was meant.
    train_df = train_df.sample(frac=params["split"], replace=True, random_state=1)
    eval_df = eval_df.sample(frac=params["split"], replace=True, random_state=1)

    model = Summarization()
    model.from_pretrained(
        model_type=params["model_type"], model_name=params["model_name"]
    )
    model.train(
        train_df=train_df,
        eval_df=eval_df,
        # Numeric params are cast explicitly. PyYAML resolves values such as
        # `1e-4` (no decimal point) to strings, not floats, so do not rely on
        # the YAML loader for numeric types.
        batch_size=int(params["batch_size"]),
        max_epochs=int(params["epochs"]),
        use_gpu=params["use_gpu"],
        learning_rate=float(params["learning_rate"]),
        num_workers=int(params["num_workers"]),
    )
    model.save_model(model_dir=params["model_dir"])

    # Export the wandb run summary as the training-metrics report.
    summary_path = "wandb/latest-run/files/wandb-summary.json"
    with open(summary_path, encoding="utf-8") as json_file:
        data = json.load(json_file)
    # Create the output directory first, so a missing reports/ folder does not
    # crash the script after training has already finished.
    os.makedirs("reports", exist_ok=True)
    with open("reports/training_metrics.txt", "w", encoding="utf-8") as fp:
        json.dump(data, fp)
# Script entry point: training runs only when this file is executed directly,
# never as a side effect of importing it.
if __name__ == "__main__":
    train_model()
|