# Fine-tune a T5 text-to-text model with Happy Transformer, report the
# evaluation loss of the fine-tuned model, and write the model back to disk.
from happytransformer import HappyTextToText, TTTrainArgs

# Load the model from the current directory. The same directory is used as
# the save target below, so the checkpoint there is overwritten in place.
happy_tt = HappyTextToText("T5", load_path=".")

# Training configuration; only the batch size is set explicitly here.
args = TTTrainArgs(batch_size=8)
happy_tt.train("data/train.csv", args=args)

# This evaluation runs after training, so the value is the post-training loss.
after_loss = happy_tt.eval("data/eval.csv")
print("After loss: ", after_loss.loss)

# Persist the fine-tuned model to the directory it was loaded from.
happy_tt.save(".")