File size: 2,179 Bytes
13bec41 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import gym
from gym import spaces
from tetris_gym.utils.board_utils import get_heights, get_bumps_from_heights
from agent.utils import calc_holes_array
import numpy as np
class CustomObsWrapper(gym.ObservationWrapper):
    """Augment a Tetris env's dict observation with derived board features.

    The wrapped env must provide ``observation_space["board"]``,
    ``observation_space["piece"]`` and integer ``width``/``height`` attributes.
    At observation time ``piece`` and ``current_pos`` are read through the
    wrapper's attribute forwarding to the wrapped env.
    """

    def __init__(self, env):
        super().__init__(env)
        self.observation_space = spaces.Dict({
            # Passed through unchanged from the wrapped env.
            "board": env.observation_space["board"],
            "piece": env.observation_space["piece"],
            # NOTE(review): low=1 excludes 0; confirm calc_holes_array never
            # yields 0 for a column, otherwise this bound should be 0.
            "holes_list": spaces.Box(
                low=1,
                high=env.height,
                shape=(env.width,),
                dtype=np.uint8,
            ),
            # x indexes columns, y indexes rows: each is bounded by its own
            # board dimension (y was previously bounded by env.width by mistake).
            "x": spaces.Discrete(env.width),
            "y": spaces.Discrete(env.height),
            # Current piece zero-padded into a fixed 4x4 grid.
            "piece_shape": spaces.Box(
                low=0,
                high=1,
                shape=(4, 4),
                dtype=np.uint8,
            ),
            "empty_above": spaces.Box(
                low=0,
                high=env.height,
                shape=(env.width,),
                dtype=np.uint8,
            ),
            "heights": spaces.Box(
                low=0,
                high=env.height,
                shape=(env.width,),
                dtype=np.uint8,
            ),
            # One entry per pair of adjacent columns, hence width - 1.
            "bumps": spaces.Box(
                low=0,
                high=env.height,
                shape=(env.width - 1,),
                dtype=int,
            ),
        })

    def observation(self, obs):
        """Build the augmented observation from the wrapped env's observation.

        Args:
            obs: Raw observation dict from the wrapped env; only its "board"
                and "piece" entries are read.

        Returns:
            dict with the same keys as ``self.observation_space``.
        """
        board = obs["board"]
        heights = get_heights(board)
        bumps = get_bumps_from_heights(heights)
        holes_array = calc_holes_array(self, board, heights)
        # How far each column's height falls short of the tallest column.
        empty_above = np.max(heights) - heights
        # The piece geometry comes from the env's current ``piece`` attribute
        # (not obs["piece"]); it is zero-padded into a constant 4x4 grid, and a
        # piece larger than 4x4 would raise here.
        piece_shape = np.zeros((4, 4), dtype=np.uint8)
        piece_shape[:len(self.piece), :len(self.piece[0])] = self.piece
        return {
            "board": board,
            "x": self.current_pos["x"],
            "y": self.current_pos["y"],
            "piece_shape": piece_shape,
            "piece": obs["piece"],
            "empty_above": empty_above,
            "holes_list": holes_array,
            "heights": heights,
            "bumps": bumps,
        }
|