Spaces:
Sleeping
Sleeping
File size: 4,722 Bytes
be5548b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
from gym_minigrid.minigrid import *
from gym_minigrid.register import register
class MemoryEnv(MiniGridEnv):
    """
    Memory test environment.

    A cue object (a key or a ball) is placed in a small start room. A
    narrow hallway leads from that room to a split with one object at
    each end; exactly one of the two has the same type as the cue. The
    episode ends with the value of ``self._reward()`` when the agent
    reaches the cell next to the matching object, and with a reward of 0
    when it reaches the cell next to the other object.
    """

    def __init__(
        self,
        seed,
        size=8,
        random_length=False,
    ):
        """
        Store the hallway option and initialise the base environment.

        seed: forwarded unchanged to the base class constructor.
        size: forwarded as ``grid_size``; also sets the step budget to
            ``5 * size ** 2``.
        random_length: when true, the hallway end column is drawn at
            random in ``_gen_grid`` instead of being ``width - 3``.
        """
        # Assigned before the base constructor runs; presumably the base
        # class builds the grid during construction and _gen_grid reads
        # this flag -- TODO confirm against MiniGridEnv.
        self.random_length = random_length

        # NOTE(review): _gen_grid asserts an odd height, so the default
        # size=8 presumably cannot be used as-is -- confirm.
        step_budget = 5 * size ** 2
        super().__init__(
            seed=seed,
            grid_size=size,
            max_steps=step_budget,
            # Switch this to True for maximum speed
            see_through_walls=False,
        )

    def _gen_grid(self, width, height):
        """
        Build the level: outer walls, start room, hallway, split,
        agent start and the three objects.

        Sets ``self.grid``, ``self.agent_pos``, ``self.agent_dir``,
        ``self.success_pos``, ``self.failure_pos`` and ``self.mission``.
        ``height`` must be odd so that there is a single middle row.
        """
        self.grid = Grid(width, height)

        # Outer boundary walls
        self.grid.horz_wall(0, 0)
        self.grid.horz_wall(0, height - 1)
        self.grid.vert_wall(0, 0)
        self.grid.vert_wall(width - 1, 0)

        assert height % 2 == 1
        mid_row = height // 2
        room_top = mid_row - 2
        room_bottom = mid_row + 2

        # Column of the wall that closes the hallway (it has an opening
        # on the middle row only)
        if self.random_length:
            hallway_end = self._rand_int(4, width - 2)
        else:
            hallway_end = width - 3

        # Start room: top and bottom walls over columns 1..4
        for col in range(1, 5):
            for row in (room_top, room_bottom):
                self.grid.set(col, row, Wall())

        # Rows directly above and below the middle row
        corridor_rows = (room_top + 1, room_bottom - 1)

        # Right-hand side of the start room, leaving the middle row open
        for row in corridor_rows:
            self.grid.set(4, row, Wall())

        # Horizontal hallway walls from column 5 up to the hallway end
        for col in range(5, hallway_end):
            for row in corridor_rows:
                self.grid.set(col, row, Wall())

        # Vertical hallway: the near wall keeps an opening on the middle
        # row, the wall two columns further right is solid
        far_col = hallway_end + 2
        for row in range(height):
            if row != mid_row:
                self.grid.set(hallway_end, row, Wall())
            self.grid.set(far_col, row, Wall())

        # Agent starts on the middle row at a random column, direction 0
        start_col = self._rand_int(1, hallway_end + 1)
        self.agent_pos = (start_col, mid_row)
        self.agent_dir = 0

        # Cue object shown in the start room
        cue_cls = self._rand_elem([Key, Ball])
        self.grid.set(1, mid_row - 1, cue_cls('green'))

        # One key and one ball at the two ends of the split, random order
        upper_cls, lower_cls = self._rand_elem([[Ball, Key], [Key, Ball]])
        split_col = hallway_end + 1
        upper_obj_pos = (split_col, mid_row - 2)
        lower_obj_pos = (split_col, mid_row + 2)
        self.grid.set(*upper_obj_pos, upper_cls('green'))
        self.grid.set(*lower_obj_pos, lower_cls('green'))

        # Terminal cells are the cells adjacent to each object, on the
        # side facing the middle row
        below_upper = (upper_obj_pos[0], upper_obj_pos[1] + 1)
        above_lower = (lower_obj_pos[0], lower_obj_pos[1] - 1)
        if cue_cls == upper_cls:
            self.success_pos, self.failure_pos = below_upper, above_lower
        else:
            self.success_pos, self.failure_pos = above_lower, below_upper

        self.mission = 'go to the matching object at the end of the hallway'

    def step(self, action):
        """
        Advance the environment by one action.

        The pickup action is replaced by toggle before delegating to
        ``MiniGridEnv.step``. Afterwards the episode is terminated if the
        agent stands on the success cell (reward from ``self._reward()``)
        or on the failure cell (reward 0).

        Returns the ``(obs, reward, done, info)`` tuple.
        """
        actions = MiniGridEnv.Actions
        if action == actions.pickup:
            action = actions.toggle

        obs, reward, done, info = MiniGridEnv.step(self, action)

        here = tuple(self.agent_pos)
        if here == self.success_pos:
            reward = self._reward()
            done = True
        if here == self.failure_pos:
            reward = 0
            done = True

        return obs, reward, done, info
class MemoryS17Random(MemoryEnv):
    """Memory environment with size 17 and a randomly drawn hallway length."""

    def __init__(self, seed=None):
        super().__init__(
            seed=seed,
            size=17,
            random_length=True,
        )


# Make the environment available under its registered id.
register(
    entry_point='gym_minigrid.envs:MemoryS17Random',
    id='MiniGrid-MemoryS17Random-v0',
)
class MemoryS13Random(MemoryEnv):
    """Memory environment with size 13 and a randomly drawn hallway length."""

    def __init__(self, seed=None):
        super().__init__(
            seed=seed,
            size=13,
            random_length=True,
        )


# Make the environment available under its registered id.
register(
    entry_point='gym_minigrid.envs:MemoryS13Random',
    id='MiniGrid-MemoryS13Random-v0',
)
class MemoryS13(MemoryEnv):
    """Memory environment with size 13 and the default (non-random) hallway."""

    def __init__(self, seed=None):
        super().__init__(
            seed=seed,
            size=13,
        )


# Make the environment available under its registered id.
register(
    entry_point='gym_minigrid.envs:MemoryS13',
    id='MiniGrid-MemoryS13-v0',
)
class MemoryS11(MemoryEnv):
    """Memory environment with size 11 and the default (non-random) hallway."""

    def __init__(self, seed=None):
        super().__init__(
            seed=seed,
            size=11,
        )


# Make the environment available under its registered id.
register(
    entry_point='gym_minigrid.envs:MemoryS11',
    id='MiniGrid-MemoryS11-v0',
)
class MemoryS9(MemoryEnv):
    """Memory environment with size 9 and the default (non-random) hallway."""

    def __init__(self, seed=None):
        super().__init__(
            seed=seed,
            size=9,
        )


# Make the environment available under its registered id.
register(
    entry_point='gym_minigrid.envs:MemoryS9',
    id='MiniGrid-MemoryS9-v0',
)
class MemoryS7(MemoryEnv):
    """Memory environment with size 7 and the default (non-random) hallway."""

    def __init__(self, seed=None):
        super().__init__(
            seed=seed,
            size=7,
        )


# Make the environment available under its registered id.
register(
    entry_point='gym_minigrid.envs:MemoryS7',
    id='MiniGrid-MemoryS7-v0',
)
|