import argparse
import logging
import os
import random

import numpy as np
import torch

from model import DocBERT
from dataset import load_data, create_data_loaders
from trainer import Trainer

# Setup logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)


def set_seed(seed):
    """Set all seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def main():
    parser = argparse.ArgumentParser(description="Train a document classification model with BERT")

    # Data arguments
    parser.add_argument("--train_data_path", type=str, required=True,
                        help="Path to the training dataset file (CSV or TSV)")
    parser.add_argument("--val_data_path", type=str, required=True,
                        help="Path to the validation dataset file (CSV or TSV)")
    parser.add_argument("--test_data_path", type=str, required=True,
                        help="Path to the test dataset file (CSV or TSV)")
    parser.add_argument("--text_column", type=str, default="text",
                        help="Name of the text column")
    parser.add_argument("--label_column", type=str, nargs="+",
                        help="Name(s) of the label column(s)")
    parser.add_argument("--val_split", type=float, default=0.1,
                        help="Validation set split ratio")
    parser.add_argument("--test_split", type=float, default=0.1,
                        help="Test set split ratio")

    # Model arguments
    parser.add_argument("--bert_model", type=str, default="bert-base-uncased",
                        help="BERT model to use (e.g., bert-base-uncased, bert-large-uncased)")
    parser.add_argument("--num_classes", type=int, required=True,
                        help="Number of classes to predict")
    parser.add_argument("--max_length", type=int, default=250,
                        help="Maximum sequence length (PhoBERT has 258 max_position_embeddings, so we choose 250)")
    parser.add_argument("--dropout", type=float, default=0.1,
                        help="Dropout probability")

    # Training arguments
    parser.add_argument("--batch_size", type=int, default=16, help="Training batch size")
    parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay for regularization")
    parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
    parser.add_argument("--grad_accum_steps", type=int, default=1, help="Gradient accumulation steps")
    parser.add_argument("--warmup_proportion", type=float, default=0.1, help="Proportion of training for LR warmup")

    # Other arguments
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    parser.add_argument("--output_dir", type=str, default="./output", help="Directory to save the model and logs")

    args = parser.parse_args()

    # Set seed for reproducibility
    set_seed(args.seed)

    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)

    # Log args for debugging
    logger.info(f"Running with arguments: {args}")

    # One category per label column; a single column (or a bare string) means
    # a single-category classification task.
    num_categories = len(args.label_column) if isinstance(args.label_column, list) else 1
    label_column = (
        args.label_column[0]
        if isinstance(args.label_column, list) and len(args.label_column) == 1
        else args.label_column
    )

    # Load and prepare data: each file is loaded separately, so the split
    # ratios are forced to send every row to the intended split.
    logger.info("Loading and preparing data...")
    train_data, _, _ = load_data(
        args.train_data_path,
        text_col=args.text_column,
        label_col=label_column,
        validation_split=0.0,
        test_split=0.0,
        seed=args.seed,
    )
    _, val_data, _ = load_data(
        args.val_data_path,
        text_col=args.text_column,
        label_col=label_column,
        validation_split=1.0,
        test_split=0.0,
        seed=args.seed,
    )
    _, _, test_data = load_data(
        args.test_data_path,
        text_col=args.text_column,
        label_col=label_column,
        validation_split=0.0,
        test_split=1.0,
        seed=args.seed,
    )

    # Create data loaders
    train_loader, val_loader, test_loader = create_data_loaders(
        train_data,
        val_data,
        test_data,
        tokenizer_name=args.bert_model,
        max_length=args.max_length,
        batch_size=args.batch_size,
        num_classes=args.num_classes,
    )

    logger.info(f"Train samples: {len(train_data[0])}, "
                f"Validation samples: {len(val_data[0])}, "
                f"Test samples: {len(test_data[0])}")

    # Initialize model
    logger.info(f"Initializing DocBERT model with {args.bert_model}...")
    model = DocBERT(
        num_classes=args.num_classes,
        bert_model_name=args.bert_model,
        dropout_prob=args.dropout,
        num_categories=num_categories,
    )

    # Count and log model parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Total parameters: {total_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,}")

    # Initialize trainer
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_proportion=args.warmup_proportion,
        gradient_accumulation_steps=args.grad_accum_steps,
        num_categories=num_categories,
        num_classes=args.num_classes,
    )

    # Train the model
    logger.info("Starting training...")
    save_path = os.path.join(args.output_dir, args.bert_model.replace("/", "_") + "_finetuned.pth")
    trainer.train(epochs=args.epochs, save_path=save_path)
    logger.info("Training completed!")


if __name__ == "__main__":
    main()
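
# Example invocation (illustrative only): the script filename, data paths,
# column names, and class count below are hypothetical placeholders and must
# be adjusted to your repo and dataset. Only the flags defined in the parser
# above are used.
#
#   python train.py \
#       --train_data_path data/train.csv \
#       --val_data_path data/val.csv \
#       --test_data_path data/test.csv \
#       --text_column text \
#       --label_column label \
#       --num_classes 2 \
#       --bert_model bert-base-uncased \
#       --epochs 10 \
#       --output_dir ./output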