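"""Command-line entry point for training a DocBERT document classifier.

Example invocation (illustrative only; the script name and data paths are
placeholders, not files shipped with this repository):

    python train.py \
        --train_data_path data/train.csv \
        --val_data_path data/val.csv \
        --test_data_path data/test.csv \
        --text_column text \
        --label_column label \
        --num_classes 4 \
        --output_dir ./output
"""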
import argparse
import os
import logging
import torch
import random
import numpy as np
from model import DocBERT
from dataset import load_data, create_data_loaders
from trainer import Trainer

# Setup logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)

def set_seed(seed):
    """Set all seeds for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

def main():
    parser = argparse.ArgumentParser(description="Train a document classification model with BERT")
    
    # Data arguments
    parser.add_argument("--train_data_path", type=str, required=True, help="Path to the dataset file (CSV or TSV)")
    parser.add_argument("--val_data_path", type=str, required=True, help="Path to the validation dataset file (CSV or TSV)")
    parser.add_argument("--test_data_path", type=str, required=True, help="Path to the test dataset file (CSV or TSV)")
    parser.add_argument("--text_column", type=str, default="text", help="Name of the text column")
    parser.add_argument("--label_column", type=str, nargs="+", help="Name of the label column")
    parser.add_argument("--val_split", type=float, default=0.1, help="Validation set split ratio")
    parser.add_argument("--test_split", type=float, default=0.1, help="Test set split ratio")
    
    # Model arguments
    parser.add_argument("--bert_model", type=str, default="bert-base-uncased", 
                        help="BERT model to use (e.g., bert-base-uncased, bert-large-uncased)")
    parser.add_argument("--num_classes", type=int, required=True, help="Number of classes to predict")
    parser.add_argument("--max_length", type=int, default=250, help="Maximum sequence length (PhoBERT has 258 max_position_embeddings so we choose 250)")
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout probability")
    
    # Training arguments
    parser.add_argument("--batch_size", type=int, default=16, help="Training batch size")
    parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay for regularization")
    parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
    parser.add_argument("--grad_accum_steps", type=int, default=1, help="Gradient accumulation steps")
    parser.add_argument("--warmup_proportion", type=float, default=0.1, help="Proportion of training for LR warmup")
    
    # Other arguments
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    parser.add_argument("--output_dir", type=str, default="./output", help="Directory to save the model and logs")
    
    args = parser.parse_args()
    
    # Set seed for reproducibility
    set_seed(args.seed)
    
    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)
    
    # Log args for debugging
    logger.info(f"Running with arguments: {args}")
    
    # Each entry in --label_column becomes its own category head in DocBERT;
    # a single entry is unwrapped so load_data receives a plain column name.
    num_categories = len(args.label_column) if isinstance(args.label_column, list) else 1
    if isinstance(args.label_column, list) and len(args.label_column) == 1:
        label_column = args.label_column[0]
    else:
        label_column = args.label_column
    # Load and prepare data
    logger.info("Loading and preparing data...")
    train_data, _, _ = load_data(
        args.train_data_path,
        text_col=args.text_column,
        label_col=label_column,
        validation_split=0.0,
        test_split=0.0,
        seed=args.seed
    )

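    # Validation and test sets come from their own files: load_data is re-used
    # with the split ratio pinned to 1.0 so the whole file lands in the desired
    # split (the --val_split/--test_split flags are therefore not applied here).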
    _, val_data, _ = load_data(
        args.val_data_path,
        text_col=args.text_column,
        label_col=label_column,
        validation_split=1.0,
        test_split=0.0,
        seed=args.seed
    )

    _, _, test_data = load_data(
        args.test_data_path,
        text_col=args.text_column,
        label_col=label_column,
        validation_split=0.0,
        test_split=1.0,
        seed=args.seed
    )

    # Create data loaders
    train_loader, val_loader, test_loader = create_data_loaders(
        train_data, 
        val_data, 
        test_data,
        tokenizer_name=args.bert_model,
        max_length=args.max_length,
        batch_size=args.batch_size,
        num_classes=args.num_classes
    )
    
    logger.info(f"Train samples: {len(train_data[0])}, "
               f"Validation samples: {len(val_data[0])}, "
               f"Test samples: {len(test_data[0])}")
    
    # Initialize model
    logger.info(f"Initializing DocBERT model with {args.bert_model}...")
    model = DocBERT(
        num_classes=args.num_classes,
        bert_model_name=args.bert_model,
        dropout_prob=args.dropout,
        num_categories=num_categories
    )
    
    # Count and log model parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Total parameters: {total_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,}")
    
    # Initialize trainer
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_proportion=args.warmup_proportion,
        gradient_accumulation_steps=args.grad_accum_steps,
        num_categories=num_categories,
        num_classes=args.num_classes,
    )
    
    # Train the model
    logger.info("Starting training...")
    save_path = os.path.join(args.output_dir, args.bert_model.replace("/", "_") + "_finetuned.pth")
    trainer.train(epochs=args.epochs, save_path=save_path)
    
    logger.info("Training completed!")

if __name__ == "__main__":
    main()