TetrisAI / observation_wrapper.py
marci0929's picture
Upload with huggingface_hub
13bec41
raw
history blame contribute delete
No virus
2.18 kB
import gym
from gym import spaces
from tetris_gym.utils.board_utils import get_heights, get_bumps_from_heights
from agent.utils import calc_holes_array
import numpy as np
class CustomObsWrapper(gym.ObservationWrapper):
    """Augment the Tetris observation with hand-crafted board features.

    The wrapped env's ``"board"`` and ``"piece"`` entries are passed through
    unchanged, and the following entries are added:

    * ``"x"`` / ``"y"`` -- the active piece position, read from
      ``self.current_pos``.
    * ``"piece_shape"`` -- ``self.piece`` zero-padded into the top-left
      corner of a 4x4 ``uint8`` grid, so the entry has a fixed shape.
    * ``"heights"`` -- ``get_heights(board)``; presumably one height per
      column (the declared shape is ``(width,)``).
    * ``"bumps"`` -- ``get_bumps_from_heights(heights)``; presumably the
      height differences between adjacent columns (declared shape
      ``(width - 1,)``).
    * ``"holes_list"`` -- ``calc_holes_array(self, board, heights)``;
      presumably per-column hole counts -- TODO confirm against
      ``agent.utils``.
    * ``"empty_above"`` -- ``max(heights) - heights``, i.e. how far each
      column's top is below the tallest column.

    ``piece``, ``current_pos``, ``width`` and ``height`` are not defined in
    this class; they are read from the wrapped env (``self.piece`` and
    ``self.current_pos`` rely on gym's wrapper attribute forwarding).
    """

    def __init__(self, env):
        super().__init__(env)
        width = env.width
        height = env.height
        self.observation_space = spaces.Dict({
            "board": env.observation_space["board"],
            "piece": env.observation_space["piece"],
            # Lower bound is 0: a column without holes is the common case
            # (assuming calc_holes_array returns counts). The previous bound
            # of 1 rejected it; widening the range is harmless either way.
            "holes_list": spaces.Box(
                low=0,
                high=height,
                shape=(width,),
                dtype=np.uint8,
            ),
            # NOTE(review): assumes the piece column never goes negative
            # (e.g. for pieces with empty left columns) -- confirm in the env.
            "x": spaces.Discrete(width),
            # y is a row index, so its range is the board height, not the
            # width (Discrete(width) rejected every row >= width).
            "y": spaces.Discrete(height),
            "piece_shape": spaces.Box(
                low=0,
                high=1,
                shape=(4, 4),
                dtype=np.uint8,
            ),
            "empty_above": spaces.Box(
                low=0,
                high=height,
                shape=(width,),
                dtype=np.uint8,
            ),
            "heights": spaces.Box(
                low=0,
                high=height,
                shape=(width,),
                dtype=np.uint8,
            ),
            "bumps": spaces.Box(
                low=0,
                high=height,
                shape=(width - 1,),
                dtype=int,
            ),
        })

    def observation(self, obs):
        """Return a new observation dict with the engineered features added.

        Args:
            obs: dict observation from the wrapped env; only its ``"board"``
                and ``"piece"`` entries are read.

        Returns:
            A new dict containing ``"board"`` and ``"piece"`` from ``obs``
            plus ``"x"``, ``"y"``, ``"piece_shape"``, ``"empty_above"``,
            ``"holes_list"``, ``"heights"`` and ``"bumps"``.
        """
        board = obs["board"]
        piece = obs["piece"]

        heights = get_heights(board)
        bumps = get_bumps_from_heights(heights)
        # calc_holes_array receives the wrapper itself as its first argument;
        # presumably it reads env attributes through it -- see agent.utils.
        holes_array = calc_holes_array(self, board, heights)
        # The subtraction already produces a fresh array; no copy needed.
        empty_above = np.max(heights) - heights

        # Pad the (at most 4x4) active piece into a fixed-size grid so the
        # observation entry always has shape (4, 4).
        active_piece = np.asarray(self.piece)
        piece_shape = np.zeros((4, 4), dtype=np.uint8)
        rows, cols = active_piece.shape
        piece_shape[:rows, :cols] = active_piece

        return {
            "board": board,
            "x": self.current_pos["x"],
            "y": self.current_pos["y"],
            "piece_shape": piece_shape,
            "piece": piece,
            "empty_above": empty_above,
            "holes_list": holes_array,
            "heights": heights,
            "bumps": bumps,
        }