import os
import json
import torch
import torch.nn.functional as F
from tqdm import tqdm
from accelerate import Accelerator
from .scheduler import create_scheduler
from .metrics import setup_metrics
from .loss_function import MultiClassFocalLossWithAlpha
import wandb
from models.model_factory import create_plm_and_tokenizer
from peft import PeftModel
class Trainer:
def __init__(self, args, model, plm_model, logger):
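        """Set up metrics, optimizer, accelerator, scheduler, loss function, and training state."""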
self.args = args
self.model = model
self.plm_model = plm_model
self.logger = logger
self.device = "cuda" if torch.cuda.is_available() else "cpu"
# Setup metrics
self.metrics_dict = setup_metrics(args)
# Setup optimizer with different learning rates
if self.args.training_method == 'full':
            # Optimize both the task model and the full PLM (both groups currently use the same learning rate)
optimizer_grouped_parameters = [
{
"params": self.model.parameters(),
"lr": args.learning_rate
},
{
"params": self.plm_model.parameters(),
"lr": args.learning_rate
}
]
self.optimizer = torch.optim.AdamW(optimizer_grouped_parameters)
elif self.args.training_method in ['plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']:
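            # For PEFT methods, only the parameters injected by PEFT have requires_grad=True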
optimizer_grouped_parameters = [
{
"params": self.model.parameters(),
"lr": args.learning_rate
},
{
"params": [param for param in self.plm_model.parameters() if param.requires_grad],
"lr": args.learning_rate
}
]
self.optimizer = torch.optim.AdamW(optimizer_grouped_parameters)
else:
self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=args.learning_rate)
# Setup accelerator
self.accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
# Setup scheduler
self.scheduler = create_scheduler(args, self.optimizer)
# Setup loss function
self.loss_fn = self._setup_loss_function()
# Prepare for distributed training
if self.args.training_method in ['full', 'plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']:
self.model, self.plm_model, self.optimizer = self.accelerator.prepare(
self.model, self.plm_model, self.optimizer
)
else:
self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
if self.scheduler:
self.scheduler = self.accelerator.prepare(self.scheduler)
# Training state
self.best_val_loss = float("inf")
if self.args.monitor_strategy == 'min':
self.best_val_metric_score = float("inf")
else:
self.best_val_metric_score = -float("inf")
self.global_steps = 0
self.early_stop_counter = 0
# Save args
with open(os.path.join(self.args.output_dir, f'{self.args.output_model_name.split(".")[0]}.json'), 'w') as f:
json.dump(self.args.__dict__, f)
def _setup_loss_function(self):
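        """Select the loss function based on the problem type."""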
if self.args.problem_type == 'regression':
return torch.nn.MSELoss()
elif self.args.problem_type == 'multi_label_classification':
return torch.nn.BCEWithLogitsLoss()
else:
return torch.nn.CrossEntropyLoss()
def train(self, train_loader, val_loader):
"""Train the model."""
for epoch in range(self.args.num_epochs):
self.logger.info(f"---------- Epoch {epoch} ----------")
# Training phase
train_loss = self._train_epoch(train_loader)
self.logger.info(f'Epoch {epoch} Train Loss: {train_loss:.4f}')
# Validation phase
val_loss, val_metrics = self._validate(val_loader)
# Handle validation results (model saving, early stopping)
self._handle_validation_results(epoch, val_loss, val_metrics)
# Early stopping check
if self._check_early_stopping():
self.logger.info(f"Early stop at Epoch {epoch}")
break
def _train_epoch(self, train_loader):
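        """Run one training epoch and return the average per-sample loss."""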
self.model.train()
if self.args.training_method in ['full', 'plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']:
self.plm_model.train()
total_loss = 0
total_samples = 0
epoch_iterator = tqdm(train_loader, desc="Training")
for batch in epoch_iterator:
# choose models to accumulate
models_to_accumulate = [self.model, self.plm_model] if self.args.training_method in ['full', 'plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3'] else [self.model]
with self.accelerator.accumulate(*models_to_accumulate):
# Forward and backward
loss = self._training_step(batch)
self.accelerator.backward(loss)
# Update statistics
batch_size = batch["label"].size(0)
total_loss += loss.item() * batch_size
total_samples += batch_size
# Gradient clipping if needed
if self.args.max_grad_norm > 0:
params_to_clip = (
list(self.model.parameters()) + list(self.plm_model.parameters())
if self.args.training_method in ['full', 'plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']
else self.model.parameters()
)
self.accelerator.clip_grad_norm_(params_to_clip, self.args.max_grad_norm)
# Optimization step
self.optimizer.step()
if self.scheduler:
self.scheduler.step()
self.optimizer.zero_grad()
# Logging
self.global_steps += 1
self._log_training_step(loss)
# Update progress bar
epoch_iterator.set_postfix(
train_loss=loss.item(),
grad_step=self.global_steps // self.args.gradient_accumulation_steps
)
return total_loss / total_samples
def _training_step(self, batch):
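        """Run a forward pass on a single batch and return the loss."""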
# Move batch to device
batch = {k: v.to(self.device) for k, v in batch.items()}
# Forward pass
logits = self.model(self.plm_model, batch)
loss = self._compute_loss(logits, batch["label"])
return loss
def _validate(self, val_loader):
"""
Validate the model.
Args:
val_loader: Validation data loader
Returns:
tuple: (validation_loss, validation_metrics)
"""
self.model.eval()
if self.args.training_method in ['full', 'plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']:
self.plm_model.eval()
total_loss = 0
total_samples = 0
# Reset all metrics at the start of validation
for metric in self.metrics_dict.values():
metric.reset()
with torch.no_grad():
for batch in tqdm(val_loader, desc="Validating"):
batch = {k: v.to(self.device) for k, v in batch.items()}
# Forward pass
logits = self.model(self.plm_model, batch)
loss = self._compute_loss(logits, batch["label"])
# Update loss statistics
batch_size = len(batch["label"])
total_loss += loss.item() * batch_size
total_samples += batch_size
# Update metrics
self._update_metrics(logits, batch["label"])
# Compute average loss
avg_loss = total_loss / total_samples
# Compute final metrics
metrics_results = {name: metric.compute().item()
for name, metric in self.metrics_dict.items()}
return avg_loss, metrics_results
def test(self, test_loader):
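        """Load the best saved checkpoint and evaluate it on the test set."""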
# Load best model
self._load_best_model()
# Add a clear signal that testing is starting
self.logger.info("---------- Starting Test Phase ----------")
# Run evaluation with a custom testing function instead of reusing _validate
test_loss, test_metrics = self._test_evaluate(test_loader)
# Log results
self.logger.info("Test Results:")
self.logger.info(f"Test Loss: {test_loss:.4f}")
for name, value in test_metrics.items():
self.logger.info(f"Test {name}: {value:.4f}")
if self.args.wandb:
wandb.log({f"test/{k}": v for k, v in test_metrics.items()})
wandb.log({"test/loss": test_loss})
def _test_evaluate(self, test_loader):
"""
Dedicated evaluation function for test phase with proper labeling.
        Almost identical to _validate, but with a "Testing" progress bar.
"""
self.model.eval()
if self.args.training_method in ['full', 'plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']:
self.plm_model.eval()
total_loss = 0
total_samples = 0
# Reset all metrics at the start of testing
for metric in self.metrics_dict.values():
metric.reset()
with torch.no_grad():
# Note the desc is "Testing" instead of "Validating"
for batch in tqdm(test_loader, desc="Testing"):
batch = {k: v.to(self.device) for k, v in batch.items()}
# Forward pass
logits = self.model(self.plm_model, batch)
loss = self._compute_loss(logits, batch["label"])
# Update loss statistics
batch_size = len(batch["label"])
total_loss += loss.item() * batch_size
total_samples += batch_size
# Update metrics
self._update_metrics(logits, batch["label"])
# Compute average loss
avg_loss = total_loss / total_samples
# Compute final metrics
metrics_results = {name: metric.compute().item()
for name, metric in self.metrics_dict.items()}
return avg_loss, metrics_results
def _compute_loss(self, logits, labels):
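        """Compute the loss, reshaping or casting inputs as required by the problem type."""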
if self.args.problem_type == 'regression' and self.args.num_labels == 1:
return self.loss_fn(logits.squeeze(), labels.squeeze())
elif self.args.problem_type == 'multi_label_classification':
return self.loss_fn(logits, labels.float())
else:
return self.loss_fn(logits, labels)
def _update_metrics(self, logits, labels):
"""Update metrics with current batch predictions."""
for metric_name, metric in self.metrics_dict.items():
if self.args.problem_type == 'regression' and self.args.num_labels == 1:
logits = logits.view(-1, 1)
labels = labels.view(-1, 1)
metric(logits, labels)
elif self.args.problem_type == 'multi_label_classification':
metric(torch.sigmoid(logits), labels)
else:
if self.args.num_labels == 2:
if metric_name == 'auroc':
                        # Use the softmax probability of the positive class for binary AUROC
                        metric(F.softmax(logits, dim=1)[:, 1], labels)
else:
metric(torch.argmax(logits, 1), labels)
else:
if metric_name == 'auroc':
metric(F.softmax(logits, dim=1), labels)
else:
metric(torch.argmax(logits, 1), labels)
def _log_training_step(self, loss):
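        """Log the per-step training loss and current learning rate to wandb."""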
if self.args.wandb:
wandb.log({
"train/loss": loss.item(),
"train/learning_rate": self.optimizer.param_groups[0]['lr']
}, step=self.global_steps)
def _save_model(self, path):
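        """
        Save the model according to the training method.

        For 'full' and 'lora', the model and PLM state dicts are saved in a single checkpoint;
        for the plm-* PEFT methods, the model state dict is saved to `path` and the PLM's
        PEFT weights are saved separately via save_pretrained.
        """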
if self.args.training_method in ['full', 'lora']:
model_state = {k: v.cpu() for k, v in self.model.state_dict().items()}
plm_state = {k: v.cpu() for k, v in self.plm_model.state_dict().items()}
torch.save({
'model_state_dict': model_state,
'plm_state_dict': plm_state
}, path)
elif self.args.training_method == "plm-lora":
model_state = {k: v.cpu() for k, v in self.model.state_dict().items()}
torch.save(model_state, path)
plm_lora_path = path.replace('.pt', '_lora')
self.plm_model.save_pretrained(plm_lora_path)
elif self.args.training_method == "plm-qlora":
# save model state dict
model_state = {k: v.cpu() for k, v in self.model.state_dict().items()}
torch.save(model_state, path)
plm_qlora_path = path.replace('.pt', '_qlora')
# save plm model lora weights
self.plm_model.save_pretrained(plm_qlora_path)
elif self.args.training_method == "plm-dora":
# save model state dict
model_state = {k: v.cpu() for k, v in self.model.state_dict().items()}
torch.save(model_state, path)
plm_dora_path = path.replace('.pt', '_dora')
# save plm model lora weights
self.plm_model.save_pretrained(plm_dora_path)
elif self.args.training_method == "plm-adalora":
# save model state dict
model_state = {k: v.cpu() for k, v in self.model.state_dict().items()}
torch.save(model_state, path)
plm_adalora_path = path.replace('.pt', '_adalora')
self.plm_model.save_pretrained(plm_adalora_path)
elif self.args.training_method == "plm-ia3":
model_state = {k: v.cpu() for k, v in self.model.state_dict().items()}
torch.save(model_state, path)
plm_ia3_path = path.replace('.pt', '_ia3')
self.plm_model.save_pretrained(plm_ia3_path)
else:
model_state = {k: v.cpu() for k, v in self.model.state_dict().items()}
torch.save(model_state, path)
def _load_best_model(self):
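        """Load the best checkpoint saved by _save_model, reloading and merging PEFT weights when needed."""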
path = os.path.join(self.args.output_dir, self.args.output_model_name)
if self.args.training_method in ['full', 'lora']:
checkpoint = torch.load(path, map_location="cpu")
self.model.load_state_dict(checkpoint['model_state_dict'])
self.plm_model.load_state_dict(checkpoint['plm_state_dict'])
self.model.to(self.device)
self.plm_model.to(self.device)
elif self.args.training_method == "plm-lora":
checkpoint = torch.load(path, map_location="cpu")
self.model.load_state_dict(checkpoint)
plm_lora_path = path.replace('.pt', '_lora')
_, self.plm_model = create_plm_and_tokenizer(self.args)
self.plm_model = PeftModel.from_pretrained(self.plm_model, plm_lora_path)
self.plm_model = self.plm_model.merge_and_unload()
self.model.to(self.device)
self.plm_model.to(self.device)
elif self.args.training_method == "plm-qlora":
# load model state dict
checkpoint = torch.load(path, map_location="cpu")
self.model.load_state_dict(checkpoint)
plm_qlora_path = path.replace('.pt', '_qlora')
# reload plm model and apply qlora weights
_, self.plm_model = create_plm_and_tokenizer(self.args)
self.plm_model = PeftModel.from_pretrained(self.plm_model, plm_qlora_path)
self.plm_model = self.plm_model.merge_and_unload()
self.model.to(self.device)
self.plm_model.to(self.device)
elif self.args.training_method == "plm-dora":
# load model state dict
checkpoint = torch.load(path, map_location="cpu")
self.model.load_state_dict(checkpoint)
plm_dora_path = path.replace('.pt', '_dora')
# reload plm model and apply dora weights
_, self.plm_model = create_plm_and_tokenizer(self.args)
self.plm_model = PeftModel.from_pretrained(self.plm_model, plm_dora_path)
self.plm_model = self.plm_model.merge_and_unload()
self.model.to(self.device)
self.plm_model.to(self.device)
elif self.args.training_method == "plm-adalora":
# load model state dict
checkpoint = torch.load(path, map_location="cpu")
self.model.load_state_dict(checkpoint)
plm_adalora_path = path.replace('.pt', '_adalora')
# reload plm model and apply adalora weights
_, self.plm_model = create_plm_and_tokenizer(self.args)
self.plm_model = PeftModel.from_pretrained(self.plm_model, plm_adalora_path)
self.plm_model = self.plm_model.merge_and_unload()
self.model.to(self.device)
self.plm_model.to(self.device)
elif self.args.training_method == "plm-ia3":
checkpoint = torch.load(path, map_location="cpu")
self.model.load_state_dict(checkpoint)
plm_ia3_path = path.replace('.pt', '_ia3')
_, self.plm_model = create_plm_and_tokenizer(self.args)
self.plm_model = PeftModel.from_pretrained(self.plm_model, plm_ia3_path)
self.plm_model = self.plm_model.merge_and_unload()
self.model.to(self.device)
self.plm_model.to(self.device)
else:
checkpoint = torch.load(path, map_location="cpu")
self.model.load_state_dict(checkpoint)
self.model.to(self.device)
def _handle_validation_results(self, epoch: int, val_loss: float, val_metrics: dict):
"""
Handle validation results, including model saving and early stopping checks.
Args:
epoch: Current epoch number
val_loss: Validation loss
val_metrics: Dictionary of validation metrics
"""
# Log validation results
self.logger.info(f'Epoch {epoch} Val Loss: {val_loss:.4f}')
for metric_name, metric_value in val_metrics.items():
self.logger.info(f'Epoch {epoch} Val {metric_name}: {metric_value:.4f}')
if self.args.wandb:
wandb.log({
"val/loss": val_loss,
**{f"val/{k}": v for k, v in val_metrics.items()}
}, step=self.global_steps)
# Check if we should save the model
should_save = False
monitor_value = val_loss
# If monitoring a specific metric
if self.args.monitor != 'loss' and self.args.monitor in val_metrics:
monitor_value = val_metrics[self.args.monitor]
# Check if current result is better
if self.args.monitor_strategy == 'min':
if monitor_value < self.best_val_metric_score:
should_save = True
self.best_val_metric_score = monitor_value
self.early_stop_counter = 0
else:
self.early_stop_counter += 1
else: # strategy == 'max'
if monitor_value > self.best_val_metric_score:
should_save = True
self.best_val_metric_score = monitor_value
self.early_stop_counter = 0
else:
self.early_stop_counter += 1
# Save model if improved
if should_save:
self.logger.info(f"Saving model with best val {self.args.monitor}: {monitor_value:.4f}")
save_path = os.path.join(self.args.output_dir, self.args.output_model_name)
self._save_model(save_path)
def _check_early_stopping(self) -> bool:
"""
Check if training should be stopped early.
Returns:
bool: True if training should stop, False otherwise
"""
if self.args.patience > 0 and self.early_stop_counter >= self.args.patience:
self.logger.info(f"Early stopping triggered after {self.early_stop_counter} epochs without improvement")
return True
return False
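

# Minimal usage sketch (hypothetical): the argument parsing, data loaders, logger, and the
# construction of `model` below depend on the rest of this repository and are only placeholders.
#
#   args = get_args()                                   # assumed CLI/config parser
#   _, plm_model = create_plm_and_tokenizer(args)       # second return value is the PLM, as used above
#   model = build_task_model(args)                      # hypothetical factory for the task model
#   trainer = Trainer(args, model, plm_model, logger)
#   trainer.train(train_loader, val_loader)
#   trainer.test(test_loader)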