Fix bug in the Optuna suggest call for dynamic ranges: resolve the 'min'/'max' key mismatch in the layer-freeze range dictionary
76101b4

import os
import random

import numpy as np
import pandas as pd
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# nn, AdamW, optuna, and the warmup-scheduler helpers used below are expected
# to be provided by this star import.
from .imports import *
from .model import GeneformerMultiTask
from .utils import calculate_task_specific_metrics, get_layer_freeze_range


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
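
# Note: pinning cudnn.deterministic / cudnn.benchmark as above trades some GPU
# throughput for reproducible kernel selection across runs.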


def initialize_wandb(config):
    if config.get("use_wandb", False):
        import wandb

        wandb.init(project=config["wandb_project"], config=config)
        print("Weights & Biases (wandb) initialized and will be used for logging.")
    else:
        print(
            "Weights & Biases (wandb) is not enabled. Logging will use other methods."
        )


def create_model(config, num_labels_list, device):
    model = GeneformerMultiTask(
        config["pretrained_path"],
        num_labels_list,
        dropout_rate=config["dropout_rate"],
        use_task_weights=config["use_task_weights"],
        task_weights=config["task_weights"],
        max_layers_to_freeze=config["max_layers_to_freeze"],
        use_attention_pooling=config["use_attention_pooling"],
    )
    if config["use_data_parallel"]:
        model = nn.DataParallel(model)
    return model.to(device)
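
# Note: nn.DataParallel prefixes parameter names with "module.", so the
# state_dict stored via trial.set_user_attr in objective() below will carry
# that prefix whenever use_data_parallel is enabled.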


def setup_optimizer_and_scheduler(model, config, total_steps):
    optimizer = AdamW(
        model.parameters(),
        lr=config["learning_rate"],
        weight_decay=config["weight_decay"],
    )
    warmup_steps = int(config["warmup_ratio"] * total_steps)

    if config["lr_scheduler_type"] == "linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
        )
    elif config["lr_scheduler_type"] == "cosine":
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
            num_cycles=0.5,
        )
    else:
        # Fail fast rather than falling through with scheduler undefined.
        raise ValueError(
            f"Unsupported lr_scheduler_type: {config['lr_scheduler_type']!r}"
        )

    return optimizer, scheduler
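
# Worked example of the warmup arithmetic above (illustrative numbers): with
# len(train_loader) = 500 batches and epochs = 4, total_steps = 2000, so
# warmup_ratio = 0.1 yields warmup_steps = int(0.1 * 2000) = 200.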


def train_epoch(
    model, train_loader, optimizer, scheduler, device, config, writer, epoch
):
    model.train()
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
    for batch_idx, batch in enumerate(progress_bar):
        optimizer.zero_grad()
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = [
            batch["labels"][task_name].to(device) for task_name in config["task_names"]
        ]
        loss, _, _ = model(input_ids, attention_mask, labels)
        loss.backward()
        if config["gradient_clipping"]:
            torch.nn.utils.clip_grad_norm_(model.parameters(), config["max_grad_norm"])
        optimizer.step()
        scheduler.step()

        writer.add_scalar(
            "Training Loss", loss.item(), epoch * len(train_loader) + batch_idx
        )
        if config.get("use_wandb", False):
            import wandb

            wandb.log({"Training Loss": loss.item()})

        # Update progress bar
        progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})

    return loss.item()  # Return the last batch loss


def validate_model(model, val_loader, device, config):
    model.eval()
    val_loss = 0.0
    task_true_labels = {task_name: [] for task_name in config["task_names"]}
    task_pred_labels = {task_name: [] for task_name in config["task_names"]}
    task_pred_probs = {task_name: [] for task_name in config["task_names"]}
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = [
                batch["labels"][task_name].to(device)
                for task_name in config["task_names"]
            ]
            loss, logits, _ = model(input_ids, attention_mask, labels)
            val_loss += loss.item()
            for sample_idx in range(len(batch["input_ids"])):
                for i, task_name in enumerate(config["task_names"]):
                    true_label = batch["labels"][task_name][sample_idx].item()
                    pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
                    pred_prob = (
                        torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
                    )
                    task_true_labels[task_name].append(true_label)
                    task_pred_labels[task_name].append(pred_label)
                    task_pred_probs[task_name].append(pred_prob)
    val_loss /= len(val_loader)
    return val_loss, task_true_labels, task_pred_labels, task_pred_probs
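
# The three per-task dicts returned above hold sample-aligned lists, e.g. for
# a hypothetical two-sample run with a three-class "cell_type" task:
#   task_true_labels = {"cell_type": [0, 2]}
#   task_pred_labels = {"cell_type": [0, 1]}
#   task_pred_probs = {"cell_type": [array([0.9, 0.05, 0.05]), array([0.1, 0.6, 0.3])]}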


def log_metrics(task_metrics, val_loss, config, writer, epochs):
    for task_name, metrics in task_metrics.items():
        print(
            f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}"
        )
        if config.get("use_wandb", False):
            import wandb

            wandb.log(
                {
                    f"{task_name} Validation F1 Macro": metrics["f1"],
                    f"{task_name} Validation Accuracy": metrics["accuracy"],
                }
            )
    writer.add_scalar("Validation Loss", val_loss, epochs)
    for task_name, metrics in task_metrics.items():
        writer.add_scalar(f"{task_name} - Validation F1 Macro", metrics["f1"], epochs)
        writer.add_scalar(
            f"{task_name} - Validation Accuracy", metrics["accuracy"], epochs
        )


def save_validation_predictions(
    val_cell_id_mapping,
    task_true_labels,
    task_pred_labels,
    task_pred_probs,
    config,
    trial_number=None,
):
    if trial_number is not None:
        trial_results_dir = os.path.join(config["results_dir"], f"trial_{trial_number}")
        os.makedirs(trial_results_dir, exist_ok=True)
        val_preds_file = os.path.join(trial_results_dir, "val_preds.csv")
    else:
        val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv")

    rows = []
    for sample_idx in range(len(val_cell_id_mapping)):
        row = {"Cell ID": val_cell_id_mapping[sample_idx]}
        for task_name in config["task_names"]:
            row[f"{task_name} True"] = task_true_labels[task_name][sample_idx]
            row[f"{task_name} Pred"] = task_pred_labels[task_name][sample_idx]
            row[f"{task_name} Probabilities"] = ",".join(
                map(str, task_pred_probs[task_name][sample_idx])
            )
        rows.append(row)
    df = pd.DataFrame(rows)
    df.to_csv(val_preds_file, index=False)
    print(f"Validation predictions saved to {val_preds_file}")


def train_model(
    config,
    device,
    train_loader,
    val_loader,
    train_cell_id_mapping,
    val_cell_id_mapping,
    num_labels_list,
):
    set_seed(config["seed"])
    initialize_wandb(config)

    model = create_model(config, num_labels_list, device)
    total_steps = len(train_loader) * config["epochs"]
    optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)

    log_dir = os.path.join(config["tensorboard_log_dir"], "manual_run")
    writer = SummaryWriter(log_dir=log_dir)

    epoch_progress = tqdm(range(config["epochs"]), desc="Training Progress")
    for epoch in epoch_progress:
        last_loss = train_epoch(
            model, train_loader, optimizer, scheduler, device, config, writer, epoch
        )
        epoch_progress.set_postfix({"last_loss": f"{last_loss:.4f}"})

    val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
        model, val_loader, device, config
    )
    task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
    log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
    writer.close()

    save_validation_predictions(
        val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config
    )

    if config.get("use_wandb", False):
        import wandb

        wandb.finish()

    print(f"\nFinal Validation Loss: {val_loss:.4f}")
    return val_loss, model  # Return both the validation loss and the trained model
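
# A minimal sketch of a config for a manual train_model run. Every key below
# is read somewhere in this module; the concrete values are illustrative
# assumptions, not recommendations:
#
#   config = {
#       "seed": 42,
#       "epochs": 3,
#       "pretrained_path": "/path/to/pretrained/geneformer",
#       "task_names": ["cell_type"],
#       "learning_rate": 5e-5,
#       "weight_decay": 0.01,
#       "warmup_ratio": 0.1,
#       "lr_scheduler_type": "linear",  # or "cosine"
#       "dropout_rate": 0.1,
#       "use_task_weights": False,
#       "task_weights": None,
#       "max_layers_to_freeze": 2,
#       "use_attention_pooling": False,
#       "use_data_parallel": False,
#       "gradient_clipping": True,
#       "max_grad_norm": 1.0,
#       "tensorboard_log_dir": "runs",
#       "results_dir": "results",
#       "use_wandb": False,  # add "wandb_project" if enabled
#   }
#   val_loss, model = train_model(
#       config, device, train_loader, val_loader,
#       train_cell_id_mapping, val_cell_id_mapping, num_labels_list,
#   )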


def objective(
    trial,
    train_loader,
    val_loader,
    train_cell_id_mapping,
    val_cell_id_mapping,
    num_labels_list,
    config,
    device,
):
    set_seed(config["seed"])  # Set the seed before each trial
    initialize_wandb(config)

    # Hyperparameters
    config["learning_rate"] = trial.suggest_float(
        "learning_rate",
        config["hyperparameters"]["learning_rate"]["low"],
        config["hyperparameters"]["learning_rate"]["high"],
        log=config["hyperparameters"]["learning_rate"]["log"],
    )
    config["warmup_ratio"] = trial.suggest_float(
        "warmup_ratio",
        config["hyperparameters"]["warmup_ratio"]["low"],
        config["hyperparameters"]["warmup_ratio"]["high"],
    )
    config["weight_decay"] = trial.suggest_float(
        "weight_decay",
        config["hyperparameters"]["weight_decay"]["low"],
        config["hyperparameters"]["weight_decay"]["high"],
    )
    config["dropout_rate"] = trial.suggest_float(
        "dropout_rate",
        config["hyperparameters"]["dropout_rate"]["low"],
        config["hyperparameters"]["dropout_rate"]["high"],
    )
    config["lr_scheduler_type"] = trial.suggest_categorical(
        "lr_scheduler_type", config["hyperparameters"]["lr_scheduler_type"]["choices"]
    )
    # Note: the attention-pooling search space is pinned to a single choice,
    # so this always resolves to False.
    config["use_attention_pooling"] = trial.suggest_categorical(
        "use_attention_pooling", [False]
    )

    if config["use_task_weights"]:
        # Sample one weight per task, then normalize so the weights sum to 1.
        config["task_weights"] = [
            trial.suggest_float(
                f"task_weight_{i}",
                config["hyperparameters"]["task_weights"]["low"],
                config["hyperparameters"]["task_weights"]["high"],
            )
            for i in range(len(num_labels_list))
        ]
        weight_sum = sum(config["task_weights"])
        config["task_weights"] = [
            weight / weight_sum for weight in config["task_weights"]
        ]
    else:
        config["task_weights"] = None

    # Dynamic range for max_layers_to_freeze
    freeze_range = get_layer_freeze_range(config["pretrained_path"])
    config["max_layers_to_freeze"] = trial.suggest_int(
        "max_layers_to_freeze",
        freeze_range["min"],
        freeze_range["max"],
    )
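    # get_layer_freeze_range returns a dict keyed by "min" and "max"; reading
    # those keys here is the key-mismatch fix described in the commit message.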

    model = create_model(config, num_labels_list, device)
    total_steps = len(train_loader) * config["epochs"]
    optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)

    log_dir = os.path.join(config["tensorboard_log_dir"], f"trial_{trial.number}")
    writer = SummaryWriter(log_dir=log_dir)

    for epoch in range(config["epochs"]):
        train_epoch(
            model, train_loader, optimizer, scheduler, device, config, writer, epoch
        )

    val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
        model, val_loader, device, config
    )
    task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
    log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
    writer.close()

    save_validation_predictions(
        val_cell_id_mapping,
        task_true_labels,
        task_pred_labels,
        task_pred_probs,
        config,
        trial.number,
    )

    trial.set_user_attr("model_state_dict", model.state_dict())
    trial.set_user_attr("task_weights", config["task_weights"])

    trial.report(val_loss, config["epochs"])
    if trial.should_prune():
        raise optuna.TrialPruned()

    if config.get("use_wandb", False):
        import wandb

        wandb.log(
            {
                "trial_number": trial.number,
                "val_loss": val_loss,
                **{
                    f"{task_name}_f1": metrics["f1"]
                    for task_name, metrics in task_metrics.items()
                },
                **{
                    f"{task_name}_accuracy": metrics["accuracy"]
                    for task_name, metrics in task_metrics.items()
                },
                **{
                    k: v
                    for k, v in config.items()
                    if k
                    in [
                        "learning_rate",
                        "warmup_ratio",
                        "weight_decay",
                        "dropout_rate",
                        "lr_scheduler_type",
                        "use_attention_pooling",
                        "max_layers_to_freeze",
                    ]
                },
            }
        )
        wandb.finish()

    return val_loss
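
# A minimal sketch of driving this objective with an Optuna study. The study
# settings (direction, pruner, n_trials) are illustrative assumptions:
#
#   from functools import partial
#
#   study = optuna.create_study(
#       direction="minimize", pruner=optuna.pruners.MedianPruner()
#   )
#   study.optimize(
#       partial(
#           objective,
#           train_loader=train_loader,
#           val_loader=val_loader,
#           train_cell_id_mapping=train_cell_id_mapping,
#           val_cell_id_mapping=val_cell_id_mapping,
#           num_labels_list=num_labels_list,
#           config=config,
#           device=device,
#       ),
#       n_trials=20,
#   )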