File size: 401 Bytes
f4fac26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import fire

from config import  TrainConfig, T5ModelConfig
from model.trainer import ChatTrainer


if __name__ == '__main__':
    # Default configurations for the training run and the model.
    # TrainConfig is constructed first, then T5ModelConfig.
    train_config, model_config = TrainConfig(), T5ModelConfig()

    chat_trainer = ChatTrainer(
        train_config=train_config,
        model_config=model_config,
    )

    # Hand the trainer instance to python-fire. Fire parses the command-line
    # arguments and invokes the member of `chat_trainer` they name.
    # Example: `python train.py train` calls `chat_trainer.train()`.
    fire.Fire(chat_trainer)