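"""Training entry point for the DocBERT document classifier.

Wires the dataset loaders (dataset.py), the DocBERT model (model.py), and the
Trainer (trainer.py) together behind a command-line interface.

Example invocation (illustrative only: the script name, file paths, column
name, and class count below are placeholders, not values shipped with the
repository):

    python train.py \
        --train_data_path data/train.csv \
        --val_data_path data/val.csv \
        --test_data_path data/test.csv \
        --label_column label \
        --num_classes 4 \
        --bert_model bert-base-uncased \
        --output_dir ./output
"""
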
import argparse
import os
import logging
import torch
import random
import numpy as np
from model import DocBERT
from dataset import load_data, create_data_loaders
from trainer import Trainer

# Setup logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)


def set_seed(seed):
    """Set all seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def main():
    parser = argparse.ArgumentParser(description="Train a document classification model with BERT")

    # Data arguments
    parser.add_argument("--train_data_path", type=str, required=True, help="Path to the training dataset file (CSV or TSV)")
    parser.add_argument("--val_data_path", type=str, required=True, help="Path to the validation dataset file (CSV or TSV)")
    parser.add_argument("--test_data_path", type=str, required=True, help="Path to the test dataset file (CSV or TSV)")
    parser.add_argument("--text_column", type=str, default="text", help="Name of the text column")
    parser.add_argument("--label_column", type=str, nargs="+", help="Name(s) of the label column(s); pass several for multi-category classification")
    parser.add_argument("--val_split", type=float, default=0.1, help="Validation set split ratio")
    parser.add_argument("--test_split", type=float, default=0.1, help="Test set split ratio")

    # Model arguments
    parser.add_argument("--bert_model", type=str, default="bert-base-uncased",
                        help="BERT model to use (e.g., bert-base-uncased, bert-large-uncased)")
    parser.add_argument("--num_classes", type=int, required=True, help="Number of classes to predict")
    parser.add_argument("--max_length", type=int, default=250, help="Maximum sequence length (PhoBERT has 258 max_position_embeddings, so 250 is a safe default)")
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout probability")

    # Training arguments
    parser.add_argument("--batch_size", type=int, default=16, help="Training batch size")
    parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay for regularization")
    parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
    parser.add_argument("--grad_accum_steps", type=int, default=1, help="Gradient accumulation steps")
    parser.add_argument("--warmup_proportion", type=float, default=0.1, help="Proportion of training steps for LR warmup")

    # Other arguments
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    parser.add_argument("--output_dir", type=str, default="./output", help="Directory to save the model and logs")

    args = parser.parse_args()

    # Set seed for reproducibility
    set_seed(args.seed)

    # Create output directory if it doesn't exist
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Log args for debugging
    logger.info(f"Running with arguments: {args}")

    num_categories = len(args.label_column) if isinstance(args.label_column, list) else 1
    label_column = args.label_column[0] if isinstance(args.label_column, list) and len(args.label_column) == 1 else args.label_column

    # Load and prepare data. Each file is loaded separately; a split ratio of
    # 0.0 or 1.0 is used so that all rows of a file land in the intended partition.
    logger.info("Loading and preparing data...")
    train_data, _, _ = load_data(
        args.train_data_path,
        text_col=args.text_column,
        label_col=label_column,
        validation_split=0.0,
        test_split=0.0,
        seed=args.seed
    )
    _, val_data, _ = load_data(
        args.val_data_path,
        text_col=args.text_column,
        label_col=label_column,
        validation_split=1.0,
        test_split=0.0,
        seed=args.seed
    )
    _, _, test_data = load_data(
        args.test_data_path,
        text_col=args.text_column,
        label_col=label_column,
        validation_split=0.0,
        test_split=1.0,
        seed=args.seed
    )

    # Create data loaders
    train_loader, val_loader, test_loader = create_data_loaders(
        train_data,
        val_data,
        test_data,
        tokenizer_name=args.bert_model,
        max_length=args.max_length,
        batch_size=args.batch_size,
        num_classes=args.num_classes
    )

    logger.info(f"Train samples: {len(train_data[0])}, "
                f"Validation samples: {len(val_data[0])}, "
                f"Test samples: {len(test_data[0])}")

    # Initialize model
    logger.info(f"Initializing DocBERT model with {args.bert_model}...")
    model = DocBERT(
        num_classes=args.num_classes,
        bert_model_name=args.bert_model,
        dropout_prob=args.dropout,
        num_categories=num_categories
    )

    # Count and log model parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Total parameters: {total_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,}")

    # Initialize trainer
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_proportion=args.warmup_proportion,
        gradient_accumulation_steps=args.grad_accum_steps,
        num_categories=num_categories,
        num_classes=args.num_classes,
    )

    # Train the model
    logger.info("Starting training...")
    save_path = os.path.join(args.output_dir, args.bert_model.replace("/", "_") + "_finetuned.pth")
    trainer.train(epochs=args.epochs, save_path=save_path)
    logger.info("Training completed!")


if __name__ == "__main__":
    main()