# Spaces:
# Runtime error
import yaml | |
from src.models.model import Summarization | |
import pandas as pd | |
def train_model(params_path="params.yml",
                train_path="data/processed/train.csv",
                eval_path="data/processed/validation.csv"):
    """
    Train the summarization model and save it to disk.

    Reads hyperparameters from a YAML file, loads the processed training and
    validation CSVs, initializes a ``Summarization`` model via
    ``from_pretrained``, trains it, and saves it to ``params['model_dir']``.

    Args:
        params_path: Path to the YAML parameters file.
        train_path: Path to the training CSV.
        eval_path: Path to the validation CSV.

    Raises:
        TypeError: If the parameters file does not parse to a mapping
            (e.g. it is empty).
        KeyError: If any required parameter is missing from the file.
    """
    # Keys this function reads from the parameters file.
    required_keys = (
        "model_type", "model_name", "batch_size", "max_epoch",
        "use_gpu", "learning_rate", "num_workers", "model_dir",
    )

    # Explicit encoding so parsing does not depend on the platform locale.
    with open(params_path, encoding="utf-8") as f:
        params = yaml.safe_load(f)

    # Fail fast with a clear message before doing any expensive data loading.
    if not isinstance(params, dict):
        raise TypeError(
            f"{params_path} must contain a YAML mapping, "
            f"got {type(params).__name__}"
        )
    missing = [key for key in required_keys if key not in params]
    if missing:
        raise KeyError(f"{params_path} is missing required keys: {missing}")

    # Load the data
    train_df = pd.read_csv(train_path)
    eval_df = pd.read_csv(eval_path)

    model = Summarization()
    model.from_pretrained(model_type=params['model_type'],
                          model_name=params['model_name'])
    # NOTE(review): the YAML key is 'max_epoch' (singular) while the argument
    # is 'max_epochs'; kept as-is to match the existing params.yml — confirm.
    model.train(train_df=train_df, eval_df=eval_df,
                batch_size=params['batch_size'],
                max_epochs=params['max_epoch'],
                use_gpu=params['use_gpu'],
                learning_rate=params['learning_rate'],
                num_workers=params['num_workers'])
    model.save_model(model_dir=params['model_dir'])
if __name__ == "__main__":
    # Only start training when run as a script, not when imported.
    train_model()