File size: 1,395 Bytes
0fdb130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import importlib.util
from typing import TYPE_CHECKING

from .utils import BestRun


if TYPE_CHECKING:
    from .trainer import Trainer


def is_optuna_available() -> bool:
    return importlib.util.find_spec("optuna") is not None


def default_hp_search_backend():
    if is_optuna_available():
        return "optuna"


def run_hp_search_optuna(trainer: "Trainer", n_trials: int, direction: str, **kwargs) -> BestRun:
    import optuna

    # Heavily inspired by transformers.integrations.run_hp_search_optuna
    # https://github.com/huggingface/transformers/blob/cbb8a37929c3860210f95c9ec99b8b84b8cf57a1/src/transformers/integrations.py#L160
    def _objective(trial):
        trainer.objective = None
        trainer.train(trial=trial)
        # If there hasn't been any evaluation during the training loop.
        if getattr(trainer, "objective", None) is None:
            metrics = trainer.evaluate()
            trainer.objective = trainer.compute_objective(metrics)
        return trainer.objective

    timeout = kwargs.pop("timeout", None)
    n_jobs = kwargs.pop("n_jobs", 1)
    study = optuna.create_study(direction=direction, **kwargs)
    study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs)
    best_trial = study.best_trial
    return BestRun(str(best_trial.number), best_trial.value, best_trial.params, study)