Spaces:
Runtime error
Runtime error
File size: 1,063 Bytes
a257639 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import HParam
class HParamCallback(BaseCallback):
    """Log the model's hyperparameters to TensorBoard's HPARAMS tab at training start.

    The record is written as an ``HParam`` value under the key ``"hparams"``. It is
    excluded from the stdout/log/json/csv outputs, so only the TensorBoard writer
    receives it.

    Hyperparameters are read from attributes of ``self.model``. Attributes the
    current algorithm does not define (for example ``n_epochs`` on A2C, or
    ``gae_lambda`` on off-policy algorithms) are skipped instead of raising
    ``AttributeError``. Values are coerced to types the TensorBoard hparams summary
    accepts (see ``_to_hparam_value``).

    :param verbose: verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, verbose: int = 0):
        super().__init__(verbose)

    @staticmethod
    def _to_hparam_value(value):
        """Return *value* in a form accepted as a TensorBoard hparam value.

        - ``bool``/``int``/``float``/``str`` are returned unchanged.
        - A callable is treated as an SB3 schedule (``progress_remaining -> value``).
          It is evaluated at ``1.0``, i.e. its value at the start of training.
        - Any other numeric-like object (e.g. a numpy scalar) is converted with
          ``float``.
        - Everything else falls back to ``str(value)``.
        """
        if isinstance(value, (bool, int, float, str)):
            return value
        if callable(value):
            # SB3 schedules map remaining progress (1.0 -> 0.0) to a value.
            return float(value(1.0))
        try:
            return float(value)
        except (TypeError, ValueError):
            return str(value)

    def _on_training_start(self) -> None:
        model = self.model

        # Samples collected per training iteration. This only exists for
        # algorithms that define ``n_steps`` (on-policy ones such as PPO/A2C).
        n_steps = getattr(model, "n_steps", None)
        n_envs = getattr(model, "n_envs", None)
        steps_per_iteration = (
            n_steps * n_envs if n_steps is not None and n_envs is not None else None
        )

        # Candidate hyperparameters, in display order. A value of None means the
        # algorithm does not have it. None is also not an accepted hparam value,
        # so those entries are dropped below.
        candidates = {
            "learning rate": getattr(model, "learning_rate", None),
            "steps_per_iteration": steps_per_iteration,
            "batch_size": getattr(model, "batch_size", None),
            "optim_epochs_per_iteration": getattr(model, "n_epochs", None),
            "gamma": getattr(model, "gamma", None),
            "gae_lambda": getattr(model, "gae_lambda", None),
            "ent_coef": getattr(model, "ent_coef", None),
            "vf_coef": getattr(model, "vf_coef", None),
        }

        hparam_dict = {"algorithm": model.__class__.__name__}
        hparam_dict.update(
            (name, self._to_hparam_value(value))
            for name, value in candidates.items()
            if value is not None
        )

        # Metric tags shown in the HPARAMS tab. TensorBoard takes their values from
        # scalars logged under the same tags; the 0s here are placeholders.
        # NOTE(review): "eval/mean_reward" is presumably only populated when an
        # evaluation callback logs it — confirm it is used alongside this callback.
        metric_dict = {
            "eval/mean_reward": 0,
            "train/loss": 0,
        }

        self.logger.record(
            "hparams",
            HParam(hparam_dict, metric_dict),
            exclude=("stdout", "log", "json", "csv"),
        )

    def _on_step(self) -> bool:
        # No per-step work. Returning True tells the training loop to continue.
        return True
|