import os
import json
import torch
import torch.nn.functional as F
from tqdm import tqdm
from accelerate import Accelerator
from .scheduler import create_scheduler
from .metrics import setup_metrics
from .loss_function import MultiClassFocalLossWithAlpha
import wandb
from models.model_factory import create_plm_and_tokenizer
from peft import PeftModel


class Trainer:
    """Training, validation and test loop for a task head (`model`) on top of a
    pretrained language model (`plm_model`), supporting full fine-tuning and PEFT
    methods (LoRA, QLoRA, DoRA, AdaLoRA, IA3) via Hugging Face Accelerate."""

    def __init__(self, args, model, plm_model, logger):
        self.args = args
        self.model = model
        self.plm_model = plm_model
        self.logger = logger
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Setup metrics
        self.metrics_dict = setup_metrics(args)

        # Setup optimizer; the task head and the trainable PLM parameters are kept
        # in separate parameter groups
        if self.args.training_method == 'full':
            # Both groups currently share the same base learning rate
            optimizer_grouped_parameters = [
                {"params": self.model.parameters(), "lr": args.learning_rate},
                {"params": self.plm_model.parameters(), "lr": args.learning_rate},
            ]
            self.optimizer = torch.optim.AdamW(optimizer_grouped_parameters)
        elif self.args.training_method in ['plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']:
            # Only the PLM parameters left trainable by the PEFT method are optimized
            optimizer_grouped_parameters = [
                {"params": self.model.parameters(), "lr": args.learning_rate},
                {"params": [param for param in self.plm_model.parameters() if param.requires_grad], "lr": args.learning_rate},
            ]
            self.optimizer = torch.optim.AdamW(optimizer_grouped_parameters)
        else:
            self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=args.learning_rate)

        # Setup accelerator
        self.accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)

        # Setup scheduler
        self.scheduler = create_scheduler(args, self.optimizer)

        # Setup loss function
        self.loss_fn = self._setup_loss_function()

        # Prepare for distributed training
        if self.args.training_method in ['full', 'plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']:
            self.model, self.plm_model, self.optimizer = self.accelerator.prepare(
                self.model, self.plm_model, self.optimizer
            )
        else:
            self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
        if self.scheduler:
            self.scheduler = self.accelerator.prepare(self.scheduler)

        # Training state
        self.best_val_loss = float("inf")
        if self.args.monitor_strategy == 'min':
            self.best_val_metric_score = float("inf")
        else:
            self.best_val_metric_score = -float("inf")
        self.global_steps = 0
        self.early_stop_counter = 0

        # Save args
        args_path = os.path.join(self.args.output_dir, f'{self.args.output_model_name.split(".")[0]}.json')
        with open(args_path, 'w') as f:
            json.dump(self.args.__dict__, f)

    def _setup_loss_function(self):
        if self.args.problem_type == 'regression':
            return torch.nn.MSELoss()
        elif self.args.problem_type == 'multi_label_classification':
            return torch.nn.BCEWithLogitsLoss()
        else:
            return torch.nn.CrossEntropyLoss()

    def train(self, train_loader, val_loader):
        """Train the model."""
        for epoch in range(self.args.num_epochs):
            self.logger.info(f"---------- Epoch {epoch} ----------")

            # Training phase
            train_loss = self._train_epoch(train_loader)
            self.logger.info(f'Epoch {epoch} Train Loss: {train_loss:.4f}')

            # Validation phase
            val_loss, val_metrics = self._validate(val_loader)

            # Handle validation results (model saving, early stopping)
            self._handle_validation_results(epoch, val_loss, val_metrics)

            # Early stopping check
            if self._check_early_stopping():
                self.logger.info(f"Early stop at Epoch {epoch}")
                break

    def _train_epoch(self, train_loader):
        self.model.train()
        if self.args.training_method in ['full', 'plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']:
            self.plm_model.train()

        total_loss = 0
        total_samples = 0
        epoch_iterator = tqdm(train_loader, desc="Training")
        for batch in epoch_iterator:
            # Choose which models take part in gradient accumulation
            if self.args.training_method in ['full', 'plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']:
                models_to_accumulate = [self.model, self.plm_model]
            else:
                models_to_accumulate = [self.model]

            with self.accelerator.accumulate(*models_to_accumulate):
                # Forward and backward
                loss = self._training_step(batch)
                self.accelerator.backward(loss)

                # Update statistics
                batch_size = batch["label"].size(0)
                total_loss += loss.item() * batch_size
                total_samples += batch_size

                # Gradient clipping if needed
                if self.args.max_grad_norm > 0:
                    params_to_clip = (
                        list(self.model.parameters()) + list(self.plm_model.parameters())
                        if self.args.training_method in ['full', 'plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']
                        else self.model.parameters()
                    )
                    self.accelerator.clip_grad_norm_(params_to_clip, self.args.max_grad_norm)

                # Optimization step
                self.optimizer.step()
                if self.scheduler:
                    self.scheduler.step()
                self.optimizer.zero_grad()

                # Logging
                self.global_steps += 1
                self._log_training_step(loss)

                # Update progress bar
                epoch_iterator.set_postfix(
                    train_loss=loss.item(),
                    grad_step=self.global_steps // self.args.gradient_accumulation_steps
                )

        return total_loss / total_samples

    def _training_step(self, batch):
        # Move batch to device
        batch = {k: v.to(self.device) for k, v in batch.items()}

        # Forward pass
        logits = self.model(self.plm_model, batch)
        loss = self._compute_loss(logits, batch["label"])
        return loss

    def _validate(self, val_loader):
        """
        Validate the model.

        Args:
            val_loader: Validation data loader

        Returns:
            tuple: (validation_loss, validation_metrics)
        """
        self.model.eval()
        if self.args.training_method in ['full', 'plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']:
            self.plm_model.eval()

        total_loss = 0
        total_samples = 0

        # Reset all metrics at the start of validation
        for metric in self.metrics_dict.values():
            metric.reset()

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating"):
                batch = {k: v.to(self.device) for k, v in batch.items()}

                # Forward pass
                logits = self.model(self.plm_model, batch)
                loss = self._compute_loss(logits, batch["label"])

                # Update loss statistics
                batch_size = len(batch["label"])
                total_loss += loss.item() * batch_size
                total_samples += batch_size

                # Update metrics
                self._update_metrics(logits, batch["label"])

        # Compute average loss
        avg_loss = total_loss / total_samples

        # Compute final metrics
        metrics_results = {name: metric.compute().item() for name, metric in self.metrics_dict.items()}

        return avg_loss, metrics_results

    def test(self, test_loader):
        # Load best model
        self._load_best_model()

        # Add a clear signal that testing is starting
        self.logger.info("---------- Starting Test Phase ----------")

        # Run evaluation with a dedicated testing function instead of reusing _validate
        test_loss, test_metrics = self._test_evaluate(test_loader)

        # Log results
        self.logger.info("Test Results:")
        self.logger.info(f"Test Loss: {test_loss:.4f}")
        for name, value in test_metrics.items():
            self.logger.info(f"Test {name}: {value:.4f}")

        if self.args.wandb:
            wandb.log({f"test/{k}": v for k, v in test_metrics.items()})
            wandb.log({"test/loss": test_loss})

    def _test_evaluate(self, test_loader):
        """
        Dedicated evaluation function for the test phase with proper labeling.

        Almost identical to _validate, but uses a "Testing" progress bar.
""" self.model.eval() if self.args.training_method in ['full', 'plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']: self.plm_model.eval() total_loss = 0 total_samples = 0 # Reset all metrics at the start of testing for metric in self.metrics_dict.values(): metric.reset() with torch.no_grad(): # Note the desc is "Testing" instead of "Validating" for batch in tqdm(test_loader, desc="Testing"): batch = {k: v.to(self.device) for k, v in batch.items()} # Forward pass logits = self.model(self.plm_model, batch) loss = self._compute_loss(logits, batch["label"]) # Update loss statistics batch_size = len(batch["label"]) total_loss += loss.item() * batch_size total_samples += batch_size # Update metrics self._update_metrics(logits, batch["label"]) # Compute average loss avg_loss = total_loss / total_samples # Compute final metrics metrics_results = {name: metric.compute().item() for name, metric in self.metrics_dict.items()} return avg_loss, metrics_results def _compute_loss(self, logits, labels): if self.args.problem_type == 'regression' and self.args.num_labels == 1: return self.loss_fn(logits.squeeze(), labels.squeeze()) elif self.args.problem_type == 'multi_label_classification': return self.loss_fn(logits, labels.float()) else: return self.loss_fn(logits, labels) def _update_metrics(self, logits, labels): """Update metrics with current batch predictions.""" for metric_name, metric in self.metrics_dict.items(): if self.args.problem_type == 'regression' and self.args.num_labels == 1: logits = logits.view(-1, 1) labels = labels.view(-1, 1) metric(logits, labels) elif self.args.problem_type == 'multi_label_classification': metric(torch.sigmoid(logits), labels) else: if self.args.num_labels == 2: if metric_name == 'auroc': metric(torch.sigmoid(logits[:, 1]), labels) else: metric(torch.argmax(logits, 1), labels) else: if metric_name == 'auroc': metric(F.softmax(logits, dim=1), labels) else: metric(torch.argmax(logits, 1), labels) def _log_training_step(self, loss): if self.args.wandb: wandb.log({ "train/loss": loss.item(), "train/learning_rate": self.optimizer.param_groups[0]['lr'] }, step=self.global_steps) # def _save_model(self, path): # if self.args.training_method in ['full', 'plm-lora']: # torch.save({ # 'model_state_dict': self.model.state_dict(), # 'plm_state_dict': self.plm_model.state_dict() # }, path) # else: # torch.save(self.model.state_dict(), path) # def _load_best_model(self): # path = os.path.join(self.args.output_dir, self.args.output_model_name) # if self.args.training_method in ['full', 'plm-lora']: # checkpoint = torch.load(path, weights_only=True) # self.model.load_state_dict(checkpoint['model_state_dict']) # self.plm_model.load_state_dict(checkpoint['plm_state_dict']) # else: # self.model.load_state_dict(torch.load(path, weights_only=True)) def _save_model(self, path): if self.args.training_method in ['full', 'lora']: model_state = {k: v.cpu() for k, v in self.model.state_dict().items()} plm_state = {k: v.cpu() for k, v in self.plm_model.state_dict().items()} torch.save({ 'model_state_dict': model_state, 'plm_state_dict': plm_state }, path) elif self.args.training_method == "plm-lora": model_state = {k: v.cpu() for k, v in self.model.state_dict().items()} torch.save(model_state, path) plm_lora_path = path.replace('.pt', '_lora') self.plm_model.save_pretrained(plm_lora_path) elif self.args.training_method == "plm-qlora": # save model state dict model_state = {k: v.cpu() for k, v in self.model.state_dict().items()} torch.save(model_state, path) plm_qlora_path = 
        else:
            # Only the task head is trainable; save its weights alone
            model_state = {k: v.cpu() for k, v in self.model.state_dict().items()}
            torch.save(model_state, path)

    def _load_best_model(self):
        path = os.path.join(self.args.output_dir, self.args.output_model_name)
        if self.args.training_method in ['full', 'lora']:
            checkpoint = torch.load(path, map_location="cpu")
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.plm_model.load_state_dict(checkpoint['plm_state_dict'])
            self.model.to(self.device)
            self.plm_model.to(self.device)
        elif self.args.training_method in ['plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']:
            # Load the task head state dict
            checkpoint = torch.load(path, map_location="cpu")
            self.model.load_state_dict(checkpoint)
            # Re-create the base PLM and merge the saved PEFT weights into it
            peft_suffix = '_' + self.args.training_method.split('-', 1)[1]  # e.g. '_lora', '_ia3'
            _, self.plm_model = create_plm_and_tokenizer(self.args)
            self.plm_model = PeftModel.from_pretrained(self.plm_model, path.replace('.pt', peft_suffix))
            self.plm_model = self.plm_model.merge_and_unload()
            self.model.to(self.device)
            self.plm_model.to(self.device)
        else:
            checkpoint = torch.load(path, map_location="cpu")
            self.model.load_state_dict(checkpoint)
            self.model.to(self.device)

    def _handle_validation_results(self, epoch: int, val_loss: float, val_metrics: dict):
        """
        Handle validation results, including model saving and early stopping checks.

        Args:
            epoch: Current epoch number
            val_loss: Validation loss
            val_metrics: Dictionary of validation metrics
        """
        # Log validation results
        self.logger.info(f'Epoch {epoch} Val Loss: {val_loss:.4f}')
        for metric_name, metric_value in val_metrics.items():
            self.logger.info(f'Epoch {epoch} Val {metric_name}: {metric_value:.4f}')

        if self.args.wandb:
            wandb.log({
                "val/loss": val_loss,
                **{f"val/{k}": v for k, v in val_metrics.items()}
            }, step=self.global_steps)

        # Check if we should save the model
        should_save = False
        monitor_value = val_loss

        # If monitoring a specific metric
        if self.args.monitor != 'loss' and self.args.monitor in val_metrics:
            monitor_value = val_metrics[self.args.monitor]

        # Check if current result is better
        if self.args.monitor_strategy == 'min':
            if monitor_value < self.best_val_metric_score:
                should_save = True
                self.best_val_metric_score = monitor_value
                self.early_stop_counter = 0
            else:
                self.early_stop_counter += 1
        else:  # strategy == 'max'
            if monitor_value > self.best_val_metric_score:
                should_save = True
                self.best_val_metric_score = monitor_value
                self.early_stop_counter = 0
            else:
                self.early_stop_counter += 1

        # Save model if improved
        if should_save:
            self.logger.info(f"Saving model with best val {self.args.monitor}: {monitor_value:.4f}")
            save_path = os.path.join(self.args.output_dir, self.args.output_model_name)
            self._save_model(save_path)

    def _check_early_stopping(self) -> bool:
        """
        Check if training should be stopped early.

        Returns:
            bool: True if training should stop, False otherwise
        """
        if self.args.patience > 0 and self.early_stop_counter >= self.args.patience:
            self.logger.info(f"Early stopping triggered after {self.early_stop_counter} epochs without improvement")
            return True
        return False
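

# Usage sketch (illustrative only): the objects referenced below — `args` (an
# argparse-style namespace exposing at least training_method, learning_rate,
# gradient_accumulation_steps, num_epochs, max_grad_norm, problem_type, num_labels,
# monitor, monitor_strategy, patience, wandb, output_dir and output_model_name),
# the task head `model`, the pretrained language model `plm_model`, a standard
# `logging` logger, and the PyTorch DataLoaders — are assumed to be constructed
# elsewhere in the project and are not defined in this module.
#
#     trainer = Trainer(args, model, plm_model, logger)
#     trainer.train(train_loader, val_loader)   # fit with early stopping on args.monitor
#     trainer.test(test_loader)                 # reload the best checkpoint and evaluate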