# Source: a2c-MountainCar-v0 / wrappers/initial_step_truncate_wrapper.py
# (uploaded by sgoodfriend, revision 0bbce05)
# A2C playing MountainCar-v0, from
# https://github.com/sgoodfriend/rl-algo-impls/tree/0760ef7d52b17f30219a27c18ba52c8895025ae3
import gym
import numpy as np
from typing import Any, Dict, Tuple, Union
from wrappers.vectorable_wrapper import VecotarableWrapper
# Alias used for the observation element of step()'s return annotation:
# either a NumPy array or a plain dict.
ObsType = Union[np.ndarray, dict]
# Alias used for step()'s `action` annotation: a scalar, a NumPy array,
# or a plain dict.
ActType = Union[int, float, np.ndarray, dict]
class InitialStepTruncateWrapper(VecotarableWrapper):
    """Force a single early ``done`` after a fixed number of ``step`` calls.

    The wrapper counts calls to :meth:`step` from construction. On the call
    at which the count reaches ``initial_steps_to_truncate`` it reports
    ``done=True`` regardless of what the wrapped env returned, then becomes
    a transparent pass-through for the rest of its lifetime.

    NOTE(review): presumably used to de-synchronise episode boundaries
    across vectorised envs -- confirm against the call sites.
    NOTE(review): the counter is not reset anywhere in this class, so it
    keeps running across episodes that end on their own before the
    threshold is reached -- confirm this is intended.
    NOTE(review): no truncation marker (e.g. ``TimeLimit.truncated``) is
    written into ``info``; verify consumers do not need one.
    """

    def __init__(self, env: gym.Env, initial_steps_to_truncate: int) -> None:
        """Wrap ``env`` and arm the one-off truncation.

        Args:
            env: Environment whose ``step`` is delegated to.
            initial_steps_to_truncate: Number of ``step`` calls after which
                ``done`` is forced once. A value of ``0`` disables the
                forced truncation entirely.
        """
        super().__init__(env)
        # Number of step() calls seen while the truncation is still pending.
        self.steps = 0
        self.initial_steps_to_truncate = initial_steps_to_truncate
        # True once the forced truncation has fired (or was never armed).
        self.initialized = (initial_steps_to_truncate == 0)

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, Dict[str, Any]]:
        """Step the wrapped env, overriding ``done`` once at the threshold.

        Args:
            action: Action forwarded unchanged to the wrapped env.

        Returns:
            The ``(obs, rew, done, info)`` values from the wrapped env, with
            ``done`` replaced by ``True`` on the single call where the step
            count reaches ``initial_steps_to_truncate``.
        """
        obs, rew, done, info = self.env.step(action)

        # Truncation already fired (or disabled): plain pass-through.
        if self.initialized:
            return obs, rew, done, info

        self.steps += 1
        threshold_reached = self.steps >= self.initial_steps_to_truncate
        if threshold_reached:
            print(f"Truncation at {self.steps} steps")
            # Disarm so later calls are never overridden again.
            self.initialized = True
            done = True
        return obs, rew, done, info