A2C playing HalfCheetahBulletEnv-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
The `MicrortsRewardDecayCallback` below linearly anneals every shaped reward weight except WinLoss to zero over the course of training:

```python
import numpy as np

from rl_algo_impls.runner.config import Config
from rl_algo_impls.shared.callbacks import Callback
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv


class MicrortsRewardDecayCallback(Callback):
    def __init__(
        self,
        config: Config,
        env: VecEnv,
        start_timesteps: int = 0,
    ) -> None:
        super().__init__()
        # Imported lazily so gym_microrts is only required if this callback
        # is actually used.
        from gym_microrts.envs.vec_env import MicroRTSGridModeVecEnv

        unwrapped = env.unwrapped
        assert isinstance(unwrapped, MicroRTSGridModeVecEnv)
        self.microrts_env = unwrapped
        self.base_reward_weights = self.microrts_env.reward_weight

        self.total_train_timesteps = config.n_timesteps
        self.timesteps_elapsed = start_timesteps

    def on_step(self, timesteps_elapsed: int = 1) -> bool:
        super().on_step(timesteps_elapsed)

        progress = self.timesteps_elapsed / self.total_train_timesteps
        # Decay all rewards except WinLoss (index 0), scaling them linearly
        # from their base values down to zero by the end of training.
        reward_weights = self.base_reward_weights * np.array(
            [1] + [1 - progress] * (len(self.base_reward_weights) - 1)
        )
        self.microrts_env.reward_weight = reward_weights

        return True
```
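To make the schedule concrete, here is a minimal, self-contained sketch of the same decay formula outside the callback. The base weight vector below is a hypothetical stand-in for `MicroRTSGridModeVecEnv.reward_weight` (WinLoss at index 0), and the timestep counts are illustrative, not taken from the repo's configs:

```python
import numpy as np

# Hypothetical stand-in for MicroRTSGridModeVecEnv.reward_weight:
# [WinLoss, ResourceGather, ProduceWorker, ProduceBuilding, Attack, ProduceCombatUnit]
base_reward_weights = np.array([10.0, 1.0, 1.0, 0.2, 1.0, 4.0])
total_train_timesteps = 1_000_000  # illustrative training budget

for timesteps_elapsed in (0, 250_000, 500_000, 1_000_000):
    progress = timesteps_elapsed / total_train_timesteps
    # Same formula as the callback: WinLoss keeps its weight, every other
    # shaped reward is scaled by (1 - progress).
    reward_weights = base_reward_weights * np.array(
        [1] + [1 - progress] * (len(base_reward_weights) - 1)
    )
    print(f"{timesteps_elapsed:>9}: {reward_weights}")
```

At the start of training the agent learns from all shaped rewards; by the final timestep only the sparse WinLoss signal remains, which pushes the policy to optimize for actually winning rather than for the shaping heuristics.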