import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import numpy as np
import time
from tqdm import tqdm
import logging
import os
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class Trainer:
"""
    Trainer class with techniques adapted from the Hedwig implementation
    to improve performance on document classification tasks
"""
def __init__(
self,
model,
train_loader,
val_loader,
test_loader=None,
lr=2e-5,
weight_decay=0.01,
warmup_proportion=0.1,
gradient_accumulation_steps=1,
max_grad_norm=1.0,
num_classes=2,
num_categories=1,
device=None
):
self.model = model
self.train_loader = train_loader
self.val_loader = val_loader
self.test_loader = test_loader
self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {self.device}")
self.model.to(self.device)
        # Optimizer steps per epoch (each optimizer step consumes
        # gradient_accumulation_steps batches)
        self.num_training_steps = len(train_loader) // gradient_accumulation_steps
        # Optimizer with decoupled weight decay (AdamW)
        # Bias and LayerNorm parameters are excluded from weight decay below
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
'weight_decay': weight_decay},
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
'weight_decay': 0.0}
]
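        # Example (assuming typical BERT parameter names): for a parameter named
        # 'bert.encoder.layer.0.output.LayerNorm.weight', any(nd in n for nd in
        # no_decay) is True, so it falls into the weight_decay=0.0 group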
self.optimizer = optim.AdamW(optimizer_grouped_parameters, lr=lr)
        # Learning rate scheduler: halve the LR when validation F1 plateaus
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max', factor=0.5, patience=2)
        # Cross-entropy loss (label smoothing could be enabled via the
        # label_smoothing argument for better generalization)
        self.criterion = nn.CrossEntropyLoss()
# Training parameters
self.gradient_accumulation_steps = gradient_accumulation_steps
self.max_grad_norm = max_grad_norm
# For tracking metrics
self.best_val_f1 = 0.0
self.best_model_state = None
self.num_classes = num_classes # Number of classes for classification
        # Number of label categories per document (e.g., several sentiment
        # aspects in one document, each classified into num_classes classes)
        self.num_categories = num_categories
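
        # Expected batch format (inferred from the training/evaluation loops
        # below; the exact keys depend on your dataset and collate function):
        #   batch = {
        #       'input_ids':      LongTensor (batch, seq_len),
        #       'attention_mask': LongTensor (batch, seq_len),
        #       'token_type_ids': LongTensor (batch, seq_len),
        #       'label':          LongTensor (batch,), or (batch, num_categories)
        #                         when num_categories > 1,
        #   }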
def train(self, epochs, save_path='best_model.pth'):
"""
Training loop with improved techniques
"""
logger.info(f"Starting training for {epochs} epochs")
for epoch in range(epochs):
start_time = time.time()
# Training phase
self.model.train()
train_loss = 0
all_predictions = []
all_labels = []
# Progress bar for training
train_iterator = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")
for i, batch in enumerate(train_iterator):
# Move batch to device
input_ids = batch['input_ids'].to(self.device)
attention_mask = batch['attention_mask'].to(self.device)
token_type_ids = batch['token_type_ids'].to(self.device)
labels = batch['label'].to(self.device)
# Forward pass
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids
)
                # Calculate loss
                if self.num_categories > 1:
                    # NOTE: the inner loop must not reuse `i`, which indexes the
                    # batch and drives the gradient-accumulation check below
                    total_loss = 0
                    for cat in range(self.num_categories):
                        start_idx = cat * self.num_classes
                        end_idx = (cat + 1) * self.num_classes
                        category_outputs = outputs[:, start_idx:end_idx]  # Shape (batch, num_classes)
                        category_labels = labels[:, cat]  # Shape (batch,)
                        # Labels must be in [0, num_classes - 1] for CrossEntropyLoss
                        if category_labels.max() >= self.num_classes or category_labels.min() < 0:
                            raise ValueError(
                                f"Category {cat} labels out of range [0, {self.num_classes - 1}]: "
                                f"min={category_labels.min()}, max={category_labels.max()}"
                            )
                        total_loss += self.criterion(category_outputs, category_labels)
                    loss = total_loss / self.num_categories  # Average loss over categories
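                    # Equivalent vectorized form (a sketch, assuming labels has
                    # shape (batch, num_categories)), kept as a comment only:
                    #   loss = self.criterion(outputs.view(-1, self.num_classes),
                    #                         labels.view(-1))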
else:
loss = self.criterion(outputs, labels)
# Scale loss if using gradient accumulation
if self.gradient_accumulation_steps > 1:
loss = loss / self.gradient_accumulation_steps
# Backward pass
loss.backward()
                # Update weights once enough gradients have accumulated; also
                # flush on the final batch so leftover gradients are not dropped
                if (i + 1) % self.gradient_accumulation_steps == 0 or (i + 1) == len(self.train_loader):
                    # Gradient clipping to prevent exploding gradients
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                    self.optimizer.step()
                    self.optimizer.zero_grad()
train_loss += loss.item() * self.gradient_accumulation_steps
# Get predictions for metrics
if self.num_categories > 1:
batch_size, total_classes = outputs.shape
                    if total_classes % self.num_categories != 0:
                        raise ValueError(f"Output size {total_classes} must be divisible by num_categories={self.num_categories}")
                    classes_per_group = total_classes // self.num_categories
                    # Group every classes_per_group logits along dim=1
                    reshaped = outputs.view(batch_size, -1, classes_per_group)  # shape: (batch, num_categories, classes_per_group)
# Argmax over each group of classes_per_group
preds = reshaped.argmax(dim=-1)
else:
_, preds = torch.max(outputs, dim=1)
all_predictions.extend(preds.cpu().tolist())
all_labels.extend(labels.cpu().tolist())
# Update progress bar with current loss
train_iterator.set_postfix({'loss': f"{loss.item():.4f}"})
            # Calculate training metrics
            train_loss /= len(self.train_loader)
            if self.num_categories > 1:
                # Flatten per-category predictions and labels into 1D arrays
                all_predictions = np.concatenate(all_predictions)
                all_labels = np.concatenate(all_labels)
            train_acc = accuracy_score(all_labels, all_predictions)
            train_f1 = f1_score(all_labels, all_predictions, average='macro')
# Validation phase
val_loss, val_acc, val_f1, val_precision, val_recall = self.evaluate(self.val_loader, "Validation")
# Log validation metrics
logger.info(f"Validation - Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, F1: {val_f1:.4f}, "
f"Precision: {val_precision:.4f}, Recall: {val_recall:.4f}")
# Adjust learning rate based on validation performance
self.scheduler.step(val_f1)
# Save best model
if val_f1 > self.best_val_f1:
self.best_val_f1 = val_f1
self.best_model_state = self.model.state_dict().copy()
torch.save(self.model.state_dict(), save_path)
logger.info(f"New best model saved with validation F1: {val_f1:.4f}")
# Print epoch summary
epoch_time = time.time() - start_time
logger.info(f"Epoch {epoch+1}/{epochs} - "
f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Train F1: {train_f1:.4f}, "
f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}, "
f"Time: {epoch_time:.2f}s")
print(f"Epoch {epoch+1}/{epochs} - ",
f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Train F1: {train_f1:.4f}, ",
f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}, ",
f"Time: {epoch_time:.2f}s")
# Load best model for final evaluation
if self.best_model_state is not None:
self.model.load_state_dict(self.best_model_state)
logger.info(f"Loaded best model with validation F1: {self.best_val_f1:.4f}")
# Test evaluation if test loader provided
if self.test_loader:
test_loss, test_acc, test_f1, test_precision, test_recall = self.evaluate(self.test_loader, "Test")
logger.info(f"Final test results - "
f"Loss: {test_loss:.4f}, Acc: {test_acc:.4f}, F1: {test_f1:.4f}, "
f"Precision: {test_precision:.4f}, Recall: {test_recall:.4f}")
print(f"Final test results - ",
f"Loss: {test_loss:.4f}, Acc: {test_acc:.4f}, F1: {test_f1:.4f}, ",
f"Precision: {test_precision:.4f}, Recall: {test_recall:.4f}")
    def evaluate(self, data_loader, phase="Validation", threshold=0.55):
        """
        Evaluation function for both validation and test sets.
        For multi-category outputs, class probabilities at or below `threshold`
        are zeroed before the argmax (see the prediction step below).
        """
self.model.eval()
eval_loss = 0
        # Accumulate in Python lists and convert once (avoids repeated np.append)
        all_predictions = []
        all_labels = []
# No gradient computation during evaluation
with torch.no_grad():
# Progress bar for evaluation
iterator = tqdm(data_loader, desc=f"[{phase}]")
for batch in iterator:
# Move batch to device
input_ids = batch['input_ids'].to(self.device)
attention_mask = batch['attention_mask'].to(self.device)
token_type_ids = batch['token_type_ids'].to(self.device)
labels = batch['label'].to(self.device)
# Forward pass
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids
)
# Calculate loss
                if self.num_categories > 1:
                    total_loss = 0
                    for cat in range(self.num_categories):
                        start_idx = cat * self.num_classes
                        end_idx = (cat + 1) * self.num_classes
                        category_outputs = outputs[:, start_idx:end_idx]  # Shape (batch, num_classes)
                        category_labels = labels[:, cat]  # Shape (batch,)
                        # Labels must be in [0, num_classes - 1] for CrossEntropyLoss
                        if category_labels.max() >= self.num_classes or category_labels.min() < 0:
                            raise ValueError(
                                f"Category {cat} labels out of range [0, {self.num_classes - 1}]: "
                                f"min={category_labels.min()}, max={category_labels.max()}"
                            )
                        total_loss += self.criterion(category_outputs, category_labels)
                    loss = total_loss / self.num_categories  # Average loss over categories
else:
loss = self.criterion(outputs, labels)
eval_loss += loss.item()
                # Get predictions for metrics
                if self.num_categories > 1:
                    batch_size, total_classes = outputs.shape
                    if total_classes % self.num_categories != 0:
                        raise ValueError(f"Output size {total_classes} must be divisible by num_categories={self.num_categories}")
                    classes_per_group = total_classes // self.num_categories
                    # Group every classes_per_group logits along dim=1
                    reshaped = outputs.view(batch_size, -1, classes_per_group)  # shape: (batch, num_categories, classes_per_group)
                    # Softmax over the class dimension (dim=-1), then zero out
                    # probabilities that do not clear the confidence threshold
                    probs = torch.softmax(reshaped, dim=-1)
                    probs = torch.where(probs > threshold, probs, torch.zeros_like(probs))
                    # Argmax over each group of classes_per_group
                    preds = probs.argmax(dim=-1)
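                    # Note: zeroing sub-threshold probabilities means that when
                    # no class clears the threshold, argmax falls back to index 0
                    # (effectively treating class 0 as the default prediction)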
else:
_, preds = torch.max(outputs, dim=1)
                # Flatten (handles both 1D and per-category 2D outputs) and accumulate
                all_predictions.extend(np.ravel(preds.cpu().numpy()).tolist())
                all_labels.extend(np.ravel(labels.cpu().numpy()).tolist())
# Calculate metrics
eval_loss /= len(data_loader)
accuracy = accuracy_score(all_labels, all_predictions)
f1 = f1_score(all_labels, all_predictions, average='weighted')
precision = precision_score(all_labels, all_predictions, average='weighted', zero_division=0)
recall = recall_score(all_labels, all_predictions, average='weighted', zero_division=0)
return eval_loss, accuracy, f1, precision, recall
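

# --- Usage sketch ---------------------------------------------------------
# A minimal, hypothetical example of driving this Trainer end to end with a
# dummy model and synthetic data. SyntheticDocs, DummyClassifier, and all
# sizes below are placeholders invented for illustration; substitute your
# own BERT-style model and DataLoaders in practice.
if __name__ == '__main__':
    from torch.utils.data import DataLoader, Dataset

    class SyntheticDocs(Dataset):
        """Random token ids and labels matching the batch format above."""
        def __init__(self, n=64, seq_len=32, vocab=100, num_classes=2):
            self.input_ids = torch.randint(0, vocab, (n, seq_len))
            self.labels = torch.randint(0, num_classes, (n,))

        def __len__(self):
            return len(self.labels)

        def __getitem__(self, idx):
            return {
                'input_ids': self.input_ids[idx],
                'attention_mask': torch.ones_like(self.input_ids[idx]),
                'token_type_ids': torch.zeros_like(self.input_ids[idx]),
                'label': self.labels[idx],
            }

    class DummyClassifier(nn.Module):
        """Stand-in for a BERT-based classifier: mean-pooled embeddings + linear head."""
        def __init__(self, vocab=100, hidden=16, num_classes=2):
            super().__init__()
            self.embed = nn.Embedding(vocab, hidden)
            self.head = nn.Linear(hidden, num_classes)

        def forward(self, input_ids, attention_mask=None, token_type_ids=None):
            return self.head(self.embed(input_ids).mean(dim=1))

    train_loader = DataLoader(SyntheticDocs(), batch_size=8, shuffle=True)
    val_loader = DataLoader(SyntheticDocs(n=16), batch_size=8)
    trainer = Trainer(DummyClassifier(), train_loader, val_loader, num_classes=2)
    trainer.train(epochs=1, save_path='best_model.pth')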