import os
import json
import torch
import torch.nn.functional as F
from tqdm import tqdm
from accelerate import Accelerator
from .scheduler import create_scheduler
from .metrics import setup_metrics
from .loss_function import MultiClassFocalLossWithAlpha
import wandb
from models.model_factory import create_plm_and_tokenizer
from peft import PeftModel


class Trainer:
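    """Training and evaluation loop for a task head paired with a pretrained language model (PLM).

    Handles optimizer/scheduler setup, PEFT-style fine-tuning of the PLM,
    accelerate-based gradient accumulation, metric tracking, best-checkpoint
    saving, and early stopping.

    Illustrative usage (argument and loader names are assumptions, not defined
    in this module):

        trainer = Trainer(args, model, plm_model, logger)
        trainer.train(train_loader, val_loader)
        trainer.test(test_loader)
    """
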
    def __init__(self, args, model, plm_model, logger):
        self.args = args
        self.model = model
        self.plm_model = plm_model
        self.logger = logger
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
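        # NOTE: batches are moved onto this device manually in the training and
        # evaluation steps; in the single-process setup assumed here it matches
        # the device chosen by Accelerator.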

        # Setup metrics
        self.metrics_dict = setup_metrics(args)

        # Setup optimizer; the task head and (where trainable) the PLM currently
        # share the same learning rate, args.learning_rate
        if self.args.training_method == 'full':
            # Optimize both the task head and all PLM parameters
            optimizer_grouped_parameters = [
                {
                    "params": self.model.parameters(),
                    "lr": args.learning_rate
                },
                {
                    "params": self.plm_model.parameters(),
                    "lr": args.learning_rate
                }
            ]
            self.optimizer = torch.optim.AdamW(optimizer_grouped_parameters)
        elif self.args.training_method in ['plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']:
            # Optimize the task head plus only the trainable (adapter) PLM parameters
            optimizer_grouped_parameters = [
                {
                    "params": self.model.parameters(),
                    "lr": args.learning_rate
                },
                {
                    "params": [param for param in self.plm_model.parameters() if param.requires_grad],
                    "lr": args.learning_rate
                }
            ]
            self.optimizer = torch.optim.AdamW(optimizer_grouped_parameters)
        else:
            self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=args.learning_rate)

        # Setup accelerator
        self.accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)

        # Setup scheduler
        self.scheduler = create_scheduler(args, self.optimizer)

        # Setup loss function
        self.loss_fn = self._setup_loss_function()

        # Prepare for distributed training
        if self.args.training_method in ['full', 'plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']:
            self.model, self.plm_model, self.optimizer = self.accelerator.prepare(
                self.model, self.plm_model, self.optimizer
            )
        else:
            self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
        if self.scheduler:
            self.scheduler = self.accelerator.prepare(self.scheduler)

        # Training state
        self.best_val_loss = float("inf")
        if self.args.monitor_strategy == 'min':
            self.best_val_metric_score = float("inf")
        else:
            self.best_val_metric_score = -float("inf")
        self.global_steps = 0
        self.early_stop_counter = 0

        # Save args
        with open(os.path.join(self.args.output_dir, f'{self.args.output_model_name.split(".")[0]}.json'), 'w') as f:
            json.dump(self.args.__dict__, f)

    def _setup_loss_function(self):
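        # NOTE: MultiClassFocalLossWithAlpha is imported above but not selected
        # here; presumably it can be swapped in for imbalanced multi-class tasks.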
        if self.args.problem_type == 'regression':
            return torch.nn.MSELoss()
        elif self.args.problem_type == 'multi_label_classification':
            return torch.nn.BCEWithLogitsLoss()
        else:
            return torch.nn.CrossEntropyLoss()

    def train(self, train_loader, val_loader):
        """Train the model."""
        for epoch in range(self.args.num_epochs):
            self.logger.info(f"---------- Epoch {epoch} ----------")

            # Training phase
            train_loss = self._train_epoch(train_loader)
            self.logger.info(f'Epoch {epoch} Train Loss: {train_loss:.4f}')

            # Validation phase
            val_loss, val_metrics = self._validate(val_loader)

            # Handle validation results (model saving, early stopping)
            self._handle_validation_results(epoch, val_loss, val_metrics)

            # Early stopping check
            if self._check_early_stopping():
                self.logger.info(f"Early stop at Epoch {epoch}")
                break

    def _train_epoch(self, train_loader):
        self.model.train()
        if self.args.training_method in ['full', 'plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']:
            self.plm_model.train()

        total_loss = 0
        total_samples = 0
        epoch_iterator = tqdm(train_loader, desc="Training")

        for batch in epoch_iterator:
            # choose models to accumulate
            models_to_accumulate = (
                [self.model, self.plm_model]
                if self.args.training_method in ['full', 'plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']
                else [self.model]
            )
            with self.accelerator.accumulate(*models_to_accumulate):
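                # Under accelerator.accumulate(), gradients are accumulated across
                # batches and the prepared optimizer only applies an update every
                # gradient_accumulation_steps steps.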
                # Forward and backward
                loss = self._training_step(batch)
                self.accelerator.backward(loss)

                # Update statistics
                batch_size = batch["label"].size(0)
                total_loss += loss.item() * batch_size
                total_samples += batch_size

                # Gradient clipping if needed
                if self.args.max_grad_norm > 0:
                    params_to_clip = (
                        list(self.model.parameters()) + list(self.plm_model.parameters())
                        if self.args.training_method in ['full', 'plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']
                        else self.model.parameters()
                    )
                    self.accelerator.clip_grad_norm_(params_to_clip, self.args.max_grad_norm)

                # Optimization step
                self.optimizer.step()
                if self.scheduler:
                    self.scheduler.step()
                self.optimizer.zero_grad()

                # Logging
                self.global_steps += 1
                self._log_training_step(loss)

            # Update progress bar
            epoch_iterator.set_postfix(
                train_loss=loss.item(),
                grad_step=self.global_steps // self.args.gradient_accumulation_steps
            )

        return total_loss / total_samples

    def _training_step(self, batch):
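        """Run the forward pass for one batch and return the loss.

        The task head is called as self.model(plm_model, batch) and is expected
        to return logits; the batch must contain a "label" tensor.
        """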
        # Move batch to device
        batch = {k: v.to(self.device) for k, v in batch.items()}

        # Forward pass
        logits = self.model(self.plm_model, batch)
        loss = self._compute_loss(logits, batch["label"])
        return loss

    def _validate(self, val_loader):
        """
        Validate the model.

        Args:
            val_loader: Validation data loader

        Returns:
            tuple: (validation_loss, validation_metrics)
        """
        self.model.eval()
        if self.args.training_method in ['full', 'plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']:
            self.plm_model.eval()

        total_loss = 0
        total_samples = 0

        # Reset all metrics at the start of validation
        for metric in self.metrics_dict.values():
            metric.reset()

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating"):
                batch = {k: v.to(self.device) for k, v in batch.items()}

                # Forward pass
                logits = self.model(self.plm_model, batch)
                loss = self._compute_loss(logits, batch["label"])

                # Update loss statistics
                batch_size = len(batch["label"])
                total_loss += loss.item() * batch_size
                total_samples += batch_size

                # Update metrics
                self._update_metrics(logits, batch["label"])

        # Compute average loss
        avg_loss = total_loss / total_samples

        # Compute final metrics
        metrics_results = {name: metric.compute().item()
                           for name, metric in self.metrics_dict.items()}

        return avg_loss, metrics_results

    def test(self, test_loader):
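        """Evaluate the best saved checkpoint on the test set and log the results."""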
        # Load best model
        self._load_best_model()

        # Add a clear signal that testing is starting
        self.logger.info("---------- Starting Test Phase ----------")

        # Run evaluation with a custom testing function instead of reusing _validate
        test_loss, test_metrics = self._test_evaluate(test_loader)

        # Log results
        self.logger.info("Test Results:")
        self.logger.info(f"Test Loss: {test_loss:.4f}")
        for name, value in test_metrics.items():
            self.logger.info(f"Test {name}: {value:.4f}")

        if self.args.wandb:
            wandb.log({f"test/{k}": v for k, v in test_metrics.items()})
            wandb.log({"test/loss": test_loss})

    def _test_evaluate(self, test_loader):
        """
        Dedicated evaluation function for the test phase with proper labeling.
        This is almost identical to _validate but uses a "Testing" progress bar.
        """
        self.model.eval()
        if self.args.training_method in ['full', 'plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']:
            self.plm_model.eval()

        total_loss = 0
        total_samples = 0

        # Reset all metrics at the start of testing
        for metric in self.metrics_dict.values():
            metric.reset()

        with torch.no_grad():
            # Note the desc is "Testing" instead of "Validating"
            for batch in tqdm(test_loader, desc="Testing"):
                batch = {k: v.to(self.device) for k, v in batch.items()}

                # Forward pass
                logits = self.model(self.plm_model, batch)
                loss = self._compute_loss(logits, batch["label"])

                # Update loss statistics
                batch_size = len(batch["label"])
                total_loss += loss.item() * batch_size
                total_samples += batch_size

                # Update metrics
                self._update_metrics(logits, batch["label"])

        # Compute average loss
        avg_loss = total_loss / total_samples

        # Compute final metrics
        metrics_results = {name: metric.compute().item()
                           for name, metric in self.metrics_dict.items()}

        return avg_loss, metrics_results

    def _compute_loss(self, logits, labels):
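        # Single-target regression squeezes logits/labels to 1-D for MSELoss;
        # multi-label targets are cast to float for BCEWithLogitsLoss.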
        if self.args.problem_type == 'regression' and self.args.num_labels == 1:
            return self.loss_fn(logits.squeeze(), labels.squeeze())
        elif self.args.problem_type == 'multi_label_classification':
            return self.loss_fn(logits, labels.float())
        else:
            return self.loss_fn(logits, labels)

    def _update_metrics(self, logits, labels):
        """Update metrics with current batch predictions."""
        for metric_name, metric in self.metrics_dict.items():
            if self.args.problem_type == 'regression' and self.args.num_labels == 1:
                logits = logits.view(-1, 1)
                labels = labels.view(-1, 1)
                metric(logits, labels)
            elif self.args.problem_type == 'multi_label_classification':
                metric(torch.sigmoid(logits), labels)
            else:
                if self.args.num_labels == 2:
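                    # Binary AUROC is fed the sigmoid of the positive-class logit
                    # (rank-equivalent to the raw logit); other metrics receive
                    # hard argmax predictions.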
                    if metric_name == 'auroc':
                        metric(torch.sigmoid(logits[:, 1]), labels)
                    else:
                        metric(torch.argmax(logits, 1), labels)
                else:
                    if metric_name == 'auroc':
                        metric(F.softmax(logits, dim=1), labels)
                    else:
                        metric(torch.argmax(logits, 1), labels)

    def _log_training_step(self, loss):
        if self.args.wandb:
            wandb.log({
                "train/loss": loss.item(),
                "train/learning_rate": self.optimizer.param_groups[0]['lr']
            }, step=self.global_steps)

    def _save_model(self, path):
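        # NOTE: state dicts are taken from the (possibly accelerator-wrapped)
        # modules as-is; in a multi-process run you would presumably want
        # accelerator.unwrap_model(...) before calling state_dict().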
        if self.args.training_method in ['full', 'lora']:
            model_state = {k: v.cpu() for k, v in self.model.state_dict().items()}
            plm_state = {k: v.cpu() for k, v in self.plm_model.state_dict().items()}
            torch.save({
                'model_state_dict': model_state,
                'plm_state_dict': plm_state
            }, path)
        elif self.args.training_method in ['plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']:
            # save the task-head state dict ...
            model_state = {k: v.cpu() for k, v in self.model.state_dict().items()}
            torch.save(model_state, path)
            # ... and the PLM adapter weights next to it, e.g. "<name>_lora" for plm-lora
            adapter_path = path.replace('.pt', f"_{self.args.training_method.split('-', 1)[1]}")
            self.plm_model.save_pretrained(adapter_path)
        else:
            model_state = {k: v.cpu() for k, v in self.model.state_dict().items()}
            torch.save(model_state, path)

    def _load_best_model(self):
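        # For PEFT methods the base PLM is re-created and the saved adapter
        # weights are merged back into it before evaluation.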
        path = os.path.join(self.args.output_dir, self.args.output_model_name)
        if self.args.training_method in ['full', 'lora']:
            checkpoint = torch.load(path, map_location="cpu")
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.plm_model.load_state_dict(checkpoint['plm_state_dict'])
            self.model.to(self.device)
            self.plm_model.to(self.device)
        elif self.args.training_method in ['plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']:
            # load the task-head state dict
            checkpoint = torch.load(path, map_location="cpu")
            self.model.load_state_dict(checkpoint)
            # reload the base PLM and merge the saved adapter weights into it
            adapter_path = path.replace('.pt', f"_{self.args.training_method.split('-', 1)[1]}")
            _, self.plm_model = create_plm_and_tokenizer(self.args)
            self.plm_model = PeftModel.from_pretrained(self.plm_model, adapter_path)
            self.plm_model = self.plm_model.merge_and_unload()
            self.model.to(self.device)
            self.plm_model.to(self.device)
        else:
            checkpoint = torch.load(path, map_location="cpu")
            self.model.load_state_dict(checkpoint)
            self.model.to(self.device)

    def _handle_validation_results(self, epoch: int, val_loss: float, val_metrics: dict):
        """
        Handle validation results, including model saving and early stopping checks.

        Args:
            epoch: Current epoch number
            val_loss: Validation loss
            val_metrics: Dictionary of validation metrics
        """
        # Log validation results
        self.logger.info(f'Epoch {epoch} Val Loss: {val_loss:.4f}')
        for metric_name, metric_value in val_metrics.items():
            self.logger.info(f'Epoch {epoch} Val {metric_name}: {metric_value:.4f}')

        if self.args.wandb:
            wandb.log({
                "val/loss": val_loss,
                **{f"val/{k}": v for k, v in val_metrics.items()}
            }, step=self.global_steps)

        # Check if we should save the model
        should_save = False
        monitor_value = val_loss

        # If monitoring a specific metric
        if self.args.monitor != 'loss' and self.args.monitor in val_metrics:
            monitor_value = val_metrics[self.args.monitor]

        # Check if current result is better
        if self.args.monitor_strategy == 'min':
            if monitor_value < self.best_val_metric_score:
                should_save = True
                self.best_val_metric_score = monitor_value
                self.early_stop_counter = 0
            else:
                self.early_stop_counter += 1
        else:  # strategy == 'max'
            if monitor_value > self.best_val_metric_score:
                should_save = True
                self.best_val_metric_score = monitor_value
                self.early_stop_counter = 0
            else:
                self.early_stop_counter += 1

        # Save model if improved
        if should_save:
            self.logger.info(f"Saving model with best val {self.args.monitor}: {monitor_value:.4f}")
            save_path = os.path.join(self.args.output_dir, self.args.output_model_name)
            self._save_model(save_path)

    def _check_early_stopping(self) -> bool:
        """
        Check if training should be stopped early.

        Returns:
            bool: True if training should stop, False otherwise
        """
        if self.args.patience > 0 and self.early_stop_counter >= self.args.patience:
            self.logger.info(f"Early stopping triggered after {self.early_stop_counter} epochs without improvement")
            return True
        return False