import torch

from train import Trainer
from config import get_config
from lanet_utils import prepare_dirs
from data_loader import get_data_loader


def main(config):
    """Seed PyTorch, build the training data loader, and run training.

    Args:
        config: Configuration object. This function reads ``config.seed``
            and ``config.use_gpu`` directly, and hands the whole object to
            ``prepare_dirs``, ``get_data_loader`` and ``Trainer``.
    """
    # Directory setup is delegated entirely to prepare_dirs.
    prepare_dirs(config)

    # Seed the RNGs up front so that everything built afterwards (data
    # loader, trainer) starts from the same random state on every run.
    seed = config.seed
    torch.manual_seed(seed)
    if config.use_gpu:
        torch.cuda.manual_seed(seed)

    # Build the training loader and hand it straight to the trainer.
    loader = get_data_loader(config=config)
    Trainer(config, train_loader=loader).train()


if __name__ == "__main__":
    # get_config() is unpacked as a two-item result; only the first item
    # (the config) is used here, the second (`unparsed`) is ignored.
    parsed = get_config()
    config, unparsed = parsed
    main(config)