import functools
import json

import optuna
import pandas as pd
import torch
import wandb
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler
from transformers import XLMRobertaTokenizer

from model.train import train, init_model, create_dataloaders, ToxicDataset
from model.training_config import TrainingConfig
|
|
@functools.lru_cache(maxsize=4)
def _load_tokenizer(model_name: str):
    """Load the pretrained tokenizer for ``model_name`` once per process."""
    return XLMRobertaTokenizer.from_pretrained(model_name)


def load_dataset(file_path: str, tokenizer=None, config=None,
                 model_name: str = "xlm-roberta-large"):
    """Read a CSV split and wrap it in a ``ToxicDataset``.

    Args:
        file_path: Path to the CSV file, read with ``pandas.read_csv``.
        tokenizer: Tokenizer handed to the dataset. When None, the cached
            ``XLMRobertaTokenizer`` for ``model_name`` is used, so loading
            several splits calls ``from_pretrained`` only once.
        config: ``TrainingConfig`` handed to the dataset. When None, a fresh
            default ``TrainingConfig()`` is created.
        model_name: Pretrained checkpoint used when ``tokenizer`` is None.

    Returns:
        ``ToxicDataset(df, tokenizer, config)``.
    """
    df = pd.read_csv(file_path)
    if tokenizer is None:
        tokenizer = _load_tokenizer(model_name)
    if config is None:
        # NOTE(review): a default config may not match the per-trial settings
        # used during tuning (e.g. max_length=128); confirm how ToxicDataset
        # uses it, or pass ``config`` explicitly.
        config = TrainingConfig()
    return ToxicDataset(df, tokenizer, config)
|
|
class HyperparameterTuner:
    """Optuna TPE search over fine-tuning hyperparameters.

    Each trial builds a ``TrainingConfig`` from sampled values and trains a
    model with the project's ``train`` helper. It logs the run to Weights &
    Biases and returns ``metrics['val/auc']`` as the value to maximize.
    """

    # Single source of truth for the sampled ranges. ``objective`` samples
    # from it, ``run_optimization`` prints it, and ``_save_study_results``
    # persists it next to the results.
    SEARCH_SPACE = {
        "lr": {"low": 1e-5, "high": 5e-5},
        "batch_size": [32, 64],
        "model_dropout": {"low": 0.3, "high": 0.45},
        "weight_decay": {"low": 0.01, "high": 0.03},
        "grad_accum_steps": {"low": 1, "high": 4},
    }

    # Human-readable names used in the start-of-run summary.
    _PARAM_LABELS = {
        "lr": "Learning rate",
        "batch_size": "Batch size",
        "model_dropout": "Dropout",
        "weight_decay": "Weight decay",
        "grad_accum_steps": "Gradient accumulation steps",
    }

    def __init__(self, train_dataset, val_dataset, n_trials=10):
        """Store the datasets and create a maximizing study.

        The study uses a seeded TPE sampler and a median pruner.
        """
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.n_trials = n_trials

        self.study = optuna.create_study(
            direction="maximize",
            sampler=TPESampler(seed=42),
            pruner=MedianPruner(
                n_startup_trials=2,
                n_warmup_steps=2,
                interval_steps=1,
            ),
        )

    def objective(self, trial):
        """Train one model with hyperparameters sampled from ``trial``.

        Returns:
            The final validation AUC (``metrics['val/auc']``).

        Raises:
            optuna.TrialPruned: if the pruner rejects the reported value.
            Exception: any training error is logged to W&B and re-raised, so
                Optuna records the trial as FAILED rather than PRUNED.
        """
        import gc

        space = self.SEARCH_SPACE
        config_params = {
            # Fixed architecture of the pretrained checkpoint.
            "model_name": "xlm-roberta-large",
            "hidden_size": 1024,
            "num_attention_heads": 16,

            # Sampled hyperparameters (same suggest order as before, so the
            # seeded sampler sees the same sequence of calls).
            "lr": trial.suggest_float("lr", **space["lr"], log=True),
            "batch_size": trial.suggest_categorical("batch_size", space["batch_size"]),
            "model_dropout": trial.suggest_float("model_dropout", **space["model_dropout"]),
            "weight_decay": trial.suggest_float("weight_decay", **space["weight_decay"]),
            "grad_accum_steps": trial.suggest_int("grad_accum_steps", **space["grad_accum_steps"]),

            # Fixed training settings shared by every trial.
            "epochs": 2,
            "mixed_precision": "bf16",
            "max_length": 128,
            "fp16": False,
            "distributed": False,
            "world_size": 1,
            "num_workers": 12,
            "activation_checkpointing": True,
            "tensor_float_32": True,
            "gc_frequency": 500,
        }

        config = TrainingConfig(**config_params)

        # max_memory_allocated() is a running maximum for the whole process.
        # Reset it so 'peak_gpu_memory' below reflects this trial only.
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()

        wandb.init(
            project="toxic-classification-hparam-tuning",
            name=f"trial-{trial.number}",
            config={
                **config_params,
                'trial_number': trial.number,
                'pruner': str(trial.study.pruner),
                'sampler': str(trial.study.sampler)
            },
            reinit=True,
            tags=['hyperparameter-optimization', f'trial-{trial.number}']
        )

        model = None
        try:
            model = init_model(config)
            train_loader, val_loader = create_dataloaders(
                self.train_dataset,
                self.val_dataset,
                config
            )

            metrics = train(model, train_loader, val_loader, config)

            wandb.log({
                'final_val_auc': metrics['val/auc'],
                'final_val_loss': metrics['val/loss'],
                'final_train_loss': metrics['train/loss'],
                'peak_gpu_memory': torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0,
                'trial_completed': True
            })

            # NOTE(review): the value is reported once, after training has
            # already finished, so pruning here saves no compute and only
            # discards a completed result. Real early stopping needs train()
            # to report per-epoch values. TODO: wire that in.
            trial.report(metrics['val/auc'], step=config.epochs)
            if trial.should_prune():
                wandb.log({'pruned': True})
                raise optuna.TrialPruned()

            return metrics['val/auc']

        except optuna.TrialPruned:
            # TrialPruned is an Exception subclass. Re-raise it untouched so
            # it is not logged as a failure by the handler below.
            raise
        except Exception as e:
            wandb.log({
                'error': str(e),
                'trial_failed': True
            })
            print(f"Trial failed: {str(e)}")
            # Propagate so Optuna marks the trial FAILED instead of passing
            # the crash off as a prune. run_optimization passes
            # catch=(Exception,) so the study continues.
            raise

        finally:
            # Drop the model and release cached CUDA blocks so the next trial
            # starts with as much free GPU memory as possible.
            del model
            gc.collect()
            torch.cuda.empty_cache()
            wandb.finish()

    def run_optimization(self):
        """Run ``self.n_trials`` trials, print the best one and save results.

        A trial that raises is recorded as FAILED and the study moves on.
        On Ctrl-C, results collected so far are still saved.
        """
        print("Starting hyperparameter optimization...")
        print("Search space:")
        for name, space in self.SEARCH_SPACE.items():
            label = self._PARAM_LABELS.get(name, name)
            print(f" - {label}: {self._format_space(space)}")
        print("\nFixed parameters:")
        print(" - Hidden size: 1024 (original)")
        print(" - Attention heads: 16 (original)")

        try:
            self.study.optimize(
                self.objective,
                n_trials=self.n_trials,
                timeout=None,
                # Keep going when one trial crashes (e.g. CUDA OOM); Optuna
                # records it as FAILED and logs the error.
                catch=(Exception,),
                callbacks=[self._log_trial]
            )

            best_trial = self._best_trial_or_none()
            if best_trial is None:
                print("\nNo trial completed successfully; nothing to report.")
            else:
                print("\nBest trial:")
                print(f" Value: {best_trial.value:.4f}")
                print(" Params:")
                for key, value in best_trial.params.items():
                    print(f" {key}: {value}")

            self._save_study_results()

        except KeyboardInterrupt:
            print("\nOptimization interrupted by user.")
            self._save_study_results()
        except Exception as e:
            print(f"Optimization failed: {str(e)}")
            raise

    def _log_trial(self, study, trial):
        """Optuna callback: log a completed trial and overall study progress.

        Pruned and failed trials have no value and are skipped.
        """
        if trial.value is None:
            return

        metrics = {
            "best_auc": study.best_value,
            "trial_auc": trial.value,
            "trial_number": trial.number,
            **trial.params
        }

        # Use the first trial that actually produced a value as the baseline.
        # trials[0] may have been pruned or failed (value None), and
        # subtracting None would raise TypeError and abort the study.
        first_value = self._first_completed_value(study.trials)
        if len(study.trials) > 1 and first_value is not None:
            metrics["optimization_progress"] = {
                "trials_completed": len(study.trials),
                "improvement_rate": (study.best_value - first_value) / len(study.trials),
                "best_trial_number": study.best_trial.number
            }

        # objective() calls wandb.finish() before Optuna runs callbacks, so
        # there is normally no active run here. Fall back to stdout instead of
        # calling wandb.log() without a run.
        if wandb.run is not None:
            wandb.log(metrics)
        else:
            print(f"Trial {trial.number} summary: {metrics}")

    def _best_trial_or_none(self):
        """Return ``self.study.best_trial``, or None if no trial has completed."""
        try:
            return self.study.best_trial
        except ValueError:
            # Optuna raises ValueError when the study has no COMPLETE trial.
            return None

    @staticmethod
    def _first_completed_value(trials):
        """Return the value of the first trial in ``trials`` that has one, else None."""
        return next((t.value for t in trials if t.value is not None), None)

    @staticmethod
    def _format_space(space):
        """Render a SEARCH_SPACE entry.

        ``{'low', 'high'}`` ranges become "low to high"; choice lists are
        shown as-is.
        """
        if isinstance(space, dict):
            return f"{space['low']:g} to {space['high']:g}"
        return str(space)

    def _save_study_results(self):
        """Pickle the study and write a JSON summary to ``optimization_results/``.

        Also works when no trial (or no successful trial) exists: in that case
        ``best_trial`` and ``datetime_start`` are written as null instead of
        raising.
        """
        import joblib
        from pathlib import Path
        from datetime import datetime

        results_dir = Path("optimization_results")
        results_dir.mkdir(exist_ok=True)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        study_path = results_dir / f"hparam_optimization_study_{timestamp}.pkl"
        joblib.dump(self.study, study_path)

        trials = self.study.trials
        states = optuna.trial.TrialState
        best = self._best_trial_or_none()
        start_times = [t.datetime_start for t in trials if t.datetime_start is not None]

        results = {
            "best_trial": None if best is None else {
                "number": best.number,
                "value": best.value,
                "params": best.params
            },
            "study_statistics": {
                "n_trials": len(trials),
                "n_completed": sum(t.state == states.COMPLETE for t in trials),
                "n_pruned": sum(t.state == states.PRUNED for t in trials),
                "n_failed": sum(t.state == states.FAIL for t in trials),
                "datetime_start": min(start_times).isoformat() if start_times else None,
                "datetime_complete": datetime.now().isoformat()
            },
            "search_space": self.SEARCH_SPACE,
            "trial_history": [
                {
                    "number": t.number,
                    "value": t.value,
                    "state": str(t.state),
                    "params": t.params
                }
                for t in trials
            ]
        }

        results_path = results_dir / f"optimization_results_{timestamp}.json"
        with open(results_path, "w") as f:
            json.dump(results, f, indent=4)

        print(f"\nResults saved to:")
        print(f" - Study: {study_path}")
        print(f" - Results: {results_path}")
|
|
def main():
    """Script entry point.

    Loads the train and validation CSV splits, then runs a 10-trial
    hyperparameter search over them.
    """
    # Keyword arguments are evaluated left to right, so the train split is
    # loaded before the validation split.
    tuner = HyperparameterTuner(
        train_dataset=load_dataset("dataset/split/train.csv"),
        val_dataset=load_dataset("dataset/split/val.csv"),
        n_trials=10,
    )
    tuner.run_optimization()


if __name__ == "__main__":
    main()