from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import HParam


class HParamCallback(BaseCallback):
    """Record the model's hyperparameters once, when training starts.

    The values are written as an ``HParam`` entry that is excluded from every
    output format except TensorBoard. The model attributes read here
    (``n_steps``, ``batch_size``, ``n_epochs``, ``gae_lambda``, ``vf_coef``,
    ...) match those of PPO. NOTE(review): other algorithms may not define
    all of them and would raise ``AttributeError``; confirm this is only used
    with PPO.
    """

    def __init__(self, verbose: int = 0):
        """Create the callback.

        :param verbose: Verbosity level forwarded to ``BaseCallback``.
        """
        super().__init__(verbose)

    @staticmethod
    def _as_scalar(value):
        """Return ``value``; if it is a schedule, return its initial value.

        SB3 lets some hyperparameters (e.g. ``learning_rate``) be a callable
        of ``progress_remaining``. TensorBoard's hparams only accept plain
        scalars/strings, so a callable is evaluated at
        ``progress_remaining=1.0``, its value at the start of training.
        """
        if callable(value):
            return value(1.0)
        return value

    def _on_training_start(self) -> None:
        """Log the hyperparameter dict and its placeholder metrics."""
        hparam_dict = {
            "algorithm": self.model.__class__.__name__,
            # May be a schedule (callable); log its starting value instead of
            # the function object, which TensorBoard cannot serialize.
            "learning rate": self._as_scalar(self.model.learning_rate),
            # Total samples collected per rollout across all environments.
            "steps_per_iteration": self.model.n_steps * self.model.n_envs,
            "batch_size": self.model.batch_size,
            "optim_epochs_per_iteration": self.model.n_epochs,
            "gamma": self.model.gamma,
            "gae_lambda": self.model.gae_lambda,
            "ent_coef": self.model.ent_coef,
            "vf_coef": self.model.vf_coef,
        }
        # Placeholder values. The HPARAMS tab links these names to the scalars
        # of the same tag logged elsewhere during training.
        metric_dict = {
            "eval/mean_reward": 0,
            "train/loss": 0,
        }
        self.logger.record(
            "hparams",
            HParam(hparam_dict, metric_dict),
            # HParam is only meaningful to the TensorBoard writer.
            exclude=("stdout", "log", "json", "csv"),
        )

    def _on_step(self) -> bool:
        """Continue training; returning False would abort it."""
        return True