File size: 1,943 Bytes
9dc837c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import gym
import numpy as np
from typing import Optional, Tuple, Union
# Observation type used in this module's reset/step annotations:
# either a numpy array or a dict.
ObsType = Union[np.ndarray, dict]
# Action type used in this module's step annotation: a scalar, a numpy
# array, or a dict.
ActType = Union[int, float, np.ndarray, dict]
class NoRewardTimeout(gym.Wrapper):
    """Terminate an episode early after a run of steps without reward.

    The wrapper counts consecutive steps whose reward is exactly zero. Once
    the count reaches ``n_timeout_steps`` the step is reported as ``done``.
    Optionally, when the count equals ``n_fire_steps``, the agent's action
    for that step is replaced by the environment's "FIRE" action.

    Every intervention is reported on stdout together with the running
    episode score and length.
    """

    def __init__(
        self, env: gym.Env, n_timeout_steps: int, n_fire_steps: Optional[int] = None
    ) -> None:
        """Wrap *env*.

        Args:
            env: Environment to wrap. When ``n_fire_steps`` is given,
                ``env.unwrapped`` must provide ``get_action_meanings()``.
            n_timeout_steps: Number of consecutive zero-reward steps after
                which the episode is reported as done.
            n_fire_steps: Number of consecutive zero-reward steps after which
                the next action is overridden with "FIRE". ``None`` disables
                the override.

        Raises:
            ValueError: If ``n_fire_steps`` is given but the environment's
                action meanings do not include "FIRE".
        """
        super().__init__(env)
        self.n_timeout_steps = n_timeout_steps
        self.n_fire_steps = n_fire_steps

        # Index of the "FIRE" action; stays None when forced firing is off.
        self.fire_act = None
        if n_fire_steps is not None:
            action_meanings = env.unwrapped.get_action_meanings()
            # Explicit check instead of `assert`: asserts are stripped under
            # `python -O`, which would defer the failure to `.index` below.
            if "FIRE" not in action_meanings:
                raise ValueError(
                    "n_fire_steps was given but the environment has no "
                    f"'FIRE' action; available actions: {action_meanings}"
                )
            self.fire_act = action_meanings.index("FIRE")

        self._reset_state()

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
        """Step the wrapped environment, applying the fire/timeout rules.

        Args:
            action: Action chosen by the caller. It is replaced by the
                "FIRE" action when the zero-reward streak equals
                ``n_fire_steps``.

        Returns:
            The ``(obs, rew, done, info)`` tuple from the wrapped
            environment, with ``done`` forced to ``True`` once the
            zero-reward streak reaches ``n_timeout_steps``.
        """
        # With n_fire_steps=None this comparison is never true, so the
        # override is disabled without a separate flag.
        if self.steps_since_reward == self.n_fire_steps:
            assert self.fire_act is not None
            self.print_intervention("Force fire action")
            action = self.fire_act

        obs, rew, done, info = self.env.step(action)
        self.episode_score += rew
        self.episode_step_idx += 1

        if rew != 0 or done:
            # Any non-zero reward (or a natural episode end) ends the streak.
            self.steps_since_reward = 0
        else:
            self.steps_since_reward += 1
            if self.steps_since_reward >= self.n_timeout_steps:
                self.print_intervention("Early terminate")
                done = True

        return obs, rew, done, info

    def reset(self, **kwargs) -> ObsType:
        """Clear the per-episode counters and reset the wrapped environment.

        Args:
            **kwargs: Forwarded unchanged to the wrapped environment's
                ``reset``.

        Returns:
            Whatever the wrapped environment's ``reset`` returns.
        """
        self._reset_state()
        return self.env.reset(**kwargs)

    def _reset_state(self) -> None:
        """Zero the reward-streak counter and the episode score/length."""
        self.steps_since_reward = 0
        self.episode_score = 0
        self.episode_step_idx = 0

    def print_intervention(self, tag: str) -> None:
        """Print *tag* with the current episode score and step count."""
        print(
            f"{self.__class__.__name__}: {tag} | "
            f"Score: {self.episode_score} | "
            f"Length: {self.episode_step_idx}"
        )
|