| from argparse import ArgumentParser | |
| from util import ConfigParser | |
| from logger import Wandb | |
| from trainer import Trainer | |
| from dataset import Dataset | |
| from tokenizer import Tokenizer | |
| from model import Model | |
| from logger import Wandb | |
| from export import ExportAll | |
# Command-line interface for the training script.
# NOTE(review): the original passed the descriptive sentence as `prog` (the
# program name shown in usage strings) and left `description` empty, which
# produces garbled --help output. `prog` is now the short program name and
# the human-readable text lives in `description`, where argparse expects it.
parser = ArgumentParser(
    prog='trainer',
    description='Trainer implementation, using PyTorch',
)
if __name__ == '__main__':
    # --config_path is mandatory: the original left it optional, so omitting
    # it handed `None` to ConfigParser and crashed with an opaque error far
    # from the real cause. `required=True` fails fast with a clear message.
    parser.add_argument(
        '-p', '--config_path',
        required=True,
        help='Path to the training configuration file',
    )
    args = parser.parse_args()
    config = ConfigParser(args.config_path).config

    # Load the raw dataset and fit the tokenizer on its text.
    dataset = Dataset(config.dataset)
    tokenizer = Tokenizer()
    tokenizer.train(dataset.text, max_length=config.tokenizer.max_length)
    ids = tokenizer.c_encode(dataset.text)

    # Propagate tokenizer-derived values into the model config so the
    # embedding size matches the learned vocabulary.
    config.model.tokenizer = tokenizer
    config.model.params.vocab_size = tokenizer.vocab_size

    # Batch the encoded ids; the trainer needs the batch count up front
    # (presumably for its schedule/progress reporting — confirm in Trainer).
    batches, num_batches = dataset.batch(ids)
    config.trainer.num_batches = num_batches

    # Wire model and experiment logger into the trainer config.
    model = Model(config.model)
    wandb = Wandb(config.wandb)
    config.trainer.model = model
    config.trainer.wandb = wandb

    trainer = Trainer(config.trainer)
    trainer.train(batches)

    # Persist artifacts: model weights, tokenizer, and any final exports.
    model.save_pretrained()
    tokenizer.to_file('tokenizer.bin')
    ExportAll()