File size: 2,953 Bytes
ee5d423
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import gym
import numpy as np

from typing import Any, Dict, Tuple, Union

# Type of an observation handled by the wrappers below: a NumPy array or a dict.
ObsType = Union[np.ndarray, dict]
# Type of an action passed to ``step``: a scalar, a NumPy array, or a dict.
ActType = Union[int, float, np.ndarray, dict]


class EpisodicLifeEnv(gym.Wrapper):
    """Report the loss of a life as the end of an episode while training.

    The wrapped environment must expose ``unwrapped.ale.lives()``. When
    ``training`` is true, a step on which the life count drops but stays
    above zero is returned with ``done=True``; the following ``reset`` then
    continues the running game with a no-op action instead of restarting it.
    When ``training`` is false the wrapper only tracks the life count and
    ``reset`` always resets the wrapped environment.

    Args:
        env: Environment to wrap.
        training: Whether a lost life should be reported as ``done``.
        noop_act: Action used by ``reset`` to continue after a lost life.
    """

    def __init__(self, env: gym.Env, training: bool = True, noop_act: int = 0) -> None:
        super().__init__(env)
        self.training = training
        self.noop_act = noop_act
        # True when the last step changed the life count while the wrapped
        # environment itself did not report done.
        self.life_done_continue = False
        # Life count observed after the most recent step or reset.
        self.lives = 0

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, Dict[str, Any]]:
        """Step the wrapped environment and flag a lost life as ``done``.

        Returns:
            The wrapped environment's ``(obs, rew, done, info)``, with ``done``
            forced to ``True`` in training mode when a life was lost and at
            least one life remains.
        """
        obs, rew, done, info = self.env.step(action)
        new_lives = self.env.unwrapped.ale.lives()
        # Computed from the wrapped environment's own ``done``, i.e. before it
        # is overridden below, so a real game over never counts as "continue".
        self.life_done_continue = new_lives != self.lives and not done
        # Only if training should life-end be marked as done
        if self.training and 0 < new_lives < self.lives:
            done = True
        self.lives = new_lives
        return obs, rew, done, info

    def reset(self, **kwargs) -> ObsType:
        """Continue after a lost life, or reset the wrapped environment.

        Args:
            **kwargs: Forwarded to the wrapped environment's ``reset``.

        Returns:
            The first observation of the next (pseudo-)episode.
        """
        # If life_done_continue (but not game over), then a reset should just allow the
        # game to progress to the next life.
        if self.training and self.life_done_continue:
            obs, _, done, _ = self.env.step(self.noop_act)
            if done:
                # The no-op itself ended the game: returning its observation
                # would hand the caller a finished environment, so start a
                # new game instead.
                obs = self.env.reset(**kwargs)
        else:
            obs = self.env.reset(**kwargs)
        self.lives = self.env.unwrapped.ale.lives()
        return obs


class FireOnLifeStarttEnv(gym.Wrapper):
    """Send the FIRE action after a reset and after each lost life.

    The wrapped environment must expose ``unwrapped.get_action_meanings()``
    and ``unwrapped.ale.lives()``. On ``reset`` the wrapper steps once with
    ``fire_act`` and once with action index 2. During an episode, whenever a
    step loses a life without ending the episode, the next action passed to
    ``step`` is replaced by ``fire_act``.

    Args:
        env: Environment to wrap.
        fire_act: Index of the FIRE action in the environment's action set.
    """

    def __init__(self, env: gym.Env, fire_act: int = 1) -> None:
        super().__init__(env)
        self.fire_act = fire_act
        action_meanings = env.unwrapped.get_action_meanings()
        assert action_meanings[fire_act] == "FIRE"
        # ``reset`` steps with action index 2, so at least three actions are needed.
        assert len(action_meanings) >= 3
        # Life count observed after the most recent step or reset.
        self.lives = 0
        # When True, the next call to ``step`` sends ``fire_act`` instead of
        # the caller's action.
        self.fire_on_next_step = True

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, Dict[str, Any]]:
        """Step the wrapped environment, substituting FIRE when one is pending.

        Returns:
            The wrapped environment's ``(obs, rew, done, info)`` unchanged.
        """
        if self.fire_on_next_step:
            action = self.fire_act
            self.fire_on_next_step = False
        obs, rew, done, info = self.env.step(action)
        new_lives = self.env.unwrapped.ale.lives()
        # A life was lost but the episode goes on: fire on the next step.
        if 0 < new_lives < self.lives and not done:
            self.fire_on_next_step = True
        self.lives = new_lives
        return obs, rew, done, info

    def reset(self, **kwargs) -> ObsType:
        """Reset the wrapped environment and take the two start-up steps.

        Args:
            **kwargs: Forwarded to the wrapped environment's ``reset``.

        Returns:
            The observation after the start-up steps, or the observation of a
            fresh reset if the last start-up step ended the episode.
        """
        self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(self.fire_act)
        if done:
            self.env.reset(**kwargs)
        # NOTE(review): action index 2 is hard-coded; its meaning depends on
        # the wrapped environment's action set.
        obs, _, done, _ = self.env.step(2)
        if done:
            # Keep the fresh episode's observation rather than returning the
            # terminal frame produced by the step above.
            obs = self.env.reset(**kwargs)
        # Re-sync the life count so the first step after reset compares
        # against the current value rather than one left over from earlier.
        self.lives = self.env.unwrapped.ale.lives()
        self.fire_on_next_step = False
        return obs


class ClipRewardEnv(gym.Wrapper):
    """Replace each reward by its sign while training.

    In training mode the reward returned by the wrapped environment is stored
    under ``info["unclipped_reward"]`` and ``np.sign`` of it is returned in
    its place. Outside training mode the step result is passed through
    untouched.

    Args:
        env: Environment to wrap.
        training: Whether rewards should be clipped to their sign.
    """

    def __init__(self, env: gym.Env, training: bool = True) -> None:
        super().__init__(env)
        self.training = training

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, Dict[str, Any]]:
        """Step the wrapped environment, clipping the reward in training mode."""
        observation, reward, finished, extras = self.env.step(action)
        if not self.training:
            # Evaluation: hand back the raw transition.
            return observation, reward, finished, extras
        # Keep the raw reward available to callers before clipping it.
        extras["unclipped_reward"] = reward
        return observation, np.sign(reward), finished, extras