import argparse
import logging
import os
import random

import numpy as np
import torch

from model import DocBERT
from dataset import load_data, create_data_loaders
from trainer import Trainer

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)


def set_seed(seed):
    """Set all seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def main():
    parser = argparse.ArgumentParser(description="Train a document classification model with BERT")

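    # Data arguments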
parser.add_argument("--train_data_path", type=str, required=True, help="Path to the dataset file (CSV or TSV)") |
|
parser.add_argument("--val_data_path", type=str, required=True, help="Path to the validation dataset file (CSV or TSV)") |
|
parser.add_argument("--test_data_path", type=str, required=True, help="Path to the test dataset file (CSV or TSV)") |
|
parser.add_argument("--text_column", type=str, default="text", help="Name of the text column") |
|
parser.add_argument("--label_column", type=str, nargs="+", help="Name of the label column") |
|
parser.add_argument("--val_split", type=float, default=0.1, help="Validation set split ratio") |
|
parser.add_argument("--test_split", type=float, default=0.1, help="Test set split ratio") |
|
|
|
|
|
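    # Model arguments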
parser.add_argument("--bert_model", type=str, default="bert-base-uncased", |
|
help="BERT model to use (e.g., bert-base-uncased, bert-large-uncased)") |
|
parser.add_argument("--num_classes", type=int, required=True, help="Number of classes to predict") |
|
parser.add_argument("--max_length", type=int, default=250, help="Maximum sequence length (PhoBERT has 258 max_position_embeddings so we choose 250)") |
|
parser.add_argument("--dropout", type=float, default=0.1, help="Dropout probability") |
|
|
|
|
|
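    # Training hyperparameters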
parser.add_argument("--batch_size", type=int, default=16, help="Training batch size") |
|
parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate") |
|
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay for regularization") |
|
parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs") |
|
parser.add_argument("--grad_accum_steps", type=int, default=1, help="Gradient accumulation steps") |
|
parser.add_argument("--warmup_proportion", type=float, default=0.1, help="Proportion of training for LR warmup") |
|
|
|
|
|
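    # Reproducibility and output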
parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility") |
|
parser.add_argument("--output_dir", type=str, default="./output", help="Directory to save the model and logs") |
|
|
|
args = parser.parse_args() |
|
|
|
|
|
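    # Seed all RNGs before any data loading or model initialization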
    set_seed(args.seed)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

logger.info(f"Running with arguments: {args}") |
|
|
|
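    # Support both a single label column and multiple label columns (one per category)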
    num_categories = len(args.label_column) if isinstance(args.label_column, list) else 1
    label_column = (
        args.label_column[0]
        if isinstance(args.label_column, list) and len(args.label_column) == 1
        else args.label_column
    )

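    # Each split lives in its own file, so internal splitting is disabled and the
    # 0.0 / 1.0 ratios below route every row of a file into the intended split.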
logger.info("Loading and preparing data...") |
|
train_data, _, _ = load_data( |
|
args.train_data_path, |
|
text_col=args.text_column, |
|
label_col=label_column, |
|
validation_split=0.0, |
|
test_split=0.0, |
|
seed=args.seed |
|
) |
|
|
|
    _, val_data, _ = load_data(
        args.val_data_path,
        text_col=args.text_column,
        label_col=label_column,
        validation_split=1.0,
        test_split=0.0,
        seed=args.seed
    )

    _, _, test_data = load_data(
        args.test_data_path,
        text_col=args.text_column,
        label_col=label_column,
        validation_split=0.0,
        test_split=1.0,
        seed=args.seed
    )

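    # Build DataLoaders that tokenize each split with the chosen BERT tokenizer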
    train_loader, val_loader, test_loader = create_data_loaders(
        train_data,
        val_data,
        test_data,
        tokenizer_name=args.bert_model,
        max_length=args.max_length,
        batch_size=args.batch_size,
        num_classes=args.num_classes
    )

logger.info(f"Train samples: {len(train_data[0])}, " |
|
f"Validation samples: {len(val_data[0])}, " |
|
f"Test samples: {len(test_data[0])}") |
|
|
|
|
|
logger.info(f"Initializing DocBERT model with {args.bert_model}...") |
|
model = DocBERT( |
|
num_classes=args.num_classes, |
|
bert_model_name=args.bert_model, |
|
dropout_prob=args.dropout, |
|
num_categories=num_categories |
|
) |
|
|
|
|
|
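    # Log model size: total vs. trainable parameter counts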
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Total parameters: {total_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,}")

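    # Configure the training loop with the optimization hyperparameters from the CLI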
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_proportion=args.warmup_proportion,
        gradient_accumulation_steps=args.grad_accum_steps,
        num_categories=num_categories,
        num_classes=args.num_classes,
    )

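    # Train and save; the checkpoint name is derived from the BERT model id so
    # names containing "/" (e.g. "vinai/phobert-base") stay filesystem-safe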
logger.info("Starting training...") |
|
save_path = os.path.join(args.output_dir, args.bert_model.replace("/", "_") + "_finetuned.pth") |
|
trainer.train(epochs=args.epochs, save_path=save_path) |
|
|
|
logger.info("Training completed!") |
|
|
|
if __name__ == "__main__":
    main()