Spaces:
Runtime error
Runtime error
from stable_baselines3.common.callbacks import BaseCallback | |
from stable_baselines3.common.logger import HParam | |
class HParamCallback(BaseCallback):
    """Record the model's hyperparameters for TensorBoard's HPARAMS tab at training start.

    The hyperparameters are read from ``self.model`` once, in
    ``_on_training_start``, and recorded under the ``"hparams"`` key via
    ``HParam``. The record is excluded from the stdout/log/json/csv output
    formats, so only the remaining writers (e.g. TensorBoard) receive it.

    NOTE(review): the attributes read (``n_steps``, ``batch_size``,
    ``n_epochs``, ``gae_lambda``, ``vf_coef``) presumably assume a PPO
    model; other algorithms may lack them and raise ``AttributeError``.
    """

    def __init__(self, verbose: int = 0):
        # Forward ``verbose`` like other SB3 callbacks; the default of 0
        # matches the previous argument-less constructor.
        super().__init__(verbose)

    @staticmethod
    def _as_hparam_value(value):
        """Return *value* in a form accepted as a TensorBoard hparam.

        SB3 lets ``learning_rate`` be a schedule: a callable taking the
        remaining training progress (1.0 at the start). Hparam values must
        be plain scalars/strings, so a schedule is logged as its initial
        value instead of the callable object.
        """
        if callable(value):
            return float(value(1.0))
        return value

    def _on_training_start(self) -> None:
        """Build the hyperparameter and metric dicts and record them once."""
        model = self.model
        hparam_dict = {
            "algorithm": model.__class__.__name__,
            "learning rate": self._as_hparam_value(model.learning_rate),
            # Total transitions collected per rollout across all envs.
            "steps_per_iteration": model.n_steps * model.n_envs,
            "batch_size": model.batch_size,
            "optim_epochs_per_iteration": model.n_epochs,
            "gamma": model.gamma,
            "gae_lambda": model.gae_lambda,
            "ent_coef": model.ent_coef,
            "vf_coef": model.vf_coef,
        }
        # Metric tags shown in the HPARAMS tab. The 0 values are placeholders;
        # TensorBoard links these tags to scalars logged under the same names.
        # NOTE(review): "eval/mean_reward" is presumably only written when an
        # evaluation callback runs — confirm it is present in this setup.
        metric_dict = {
            "eval/mean_reward": 0,
            "train/loss": 0,
        }
        self.logger.record(
            "hparams",
            HParam(hparam_dict, metric_dict),
            exclude=("stdout", "log", "json", "csv"),
        )

    def _on_step(self) -> bool:
        """Never request early stopping; this callback only acts at training start."""
        return True