spellcheck_model / train.py
Leonardo Yoshida
feat: add initial trained model and scripts
54f5ec9
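from happytransformer import HappyTextToText, TTTrainArgs

# Load the T5 model previously saved in the current directory.
happy_tt = HappyTextToText("T5", load_path=".")

# Fine-tune on the training pairs.
args = TTTrainArgs(batch_size=8)
happy_tt.train("data/train.csv", args=args)

# Evaluate the fine-tuned model on the held-out set.
after_loss = happy_tt.eval("data/eval.csv")
print("After loss:", after_loss.loss)

# Save the fine-tuned model, overwriting the checkpoint in the current directory.
happy_tt.save(".")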
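The script assumes that `data/train.csv` and `data/eval.csv` already exist. Happy Transformer's text-to-text trainer reads CSV files with two columns, `input` and `target`. The following is a minimal sketch of producing such a file with only the standard library. The example sentence pairs and the helper names `write_tt_csv` and `read_tt_csv` are hypothetical and are not part of this repository.

```python
import csv
import os
import tempfile

# Hypothetical misspelled/corrected pairs, for illustration only.
EXAMPLE_PAIRS = [
    ("Ths is a tset.", "This is a test."),
    ("I recieved the pakage.", "I received the package."),
]


def write_tt_csv(path, pairs):
    """Write (input, target) pairs to a CSV whose header row is
    'input,target', the column layout HappyTextToText expects."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["input", "target"])
        writer.writerows(pairs)


def read_tt_csv(path):
    """Read the file back as a list of (input, target) tuples."""
    with open(path, newline="", encoding="utf-8") as f:
        return [(row["input"], row["target"]) for row in csv.DictReader(f)]


if __name__ == "__main__":
    # Write to a temporary directory so the sketch never touches data/.
    out_dir = tempfile.mkdtemp()
    path = os.path.join(out_dir, "train.csv")
    write_tt_csv(path, EXAMPLE_PAIRS)
    print(read_tt_csv(path))
```

The `csv` module handles quoting automatically. This matters for spellcheck data because sentences often contain commas.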