""" Nanotron training script. Usage: ``` export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations torchrun --nproc_per_node=8 run_train.py --config-file config_tiny_mistral.yaml ``` """ import argparse from nanotron.trainer import DistributedTrainer from dataloader import get_dataloader from modeling_mistral import MistralForTraining from config_tiny_mistral import MistralConfig def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--config-file", type=str, required=True, help="Path to the YAML or python config file") return parser.parse_args() if __name__ == "__main__": args = get_args() config_file = args.config_file # Load trainer and data trainer = DistributedTrainer(config_file, model_class=MistralForTraining, model_config_class=MistralConfig) dataloader = get_dataloader(trainer) # Train trainer.train(dataloader)