File size: 2,179 Bytes
13bec41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import gym
from gym import spaces
from tetris_gym.utils.board_utils import get_heights, get_bumps_from_heights
from agent.utils import calc_holes_array

import numpy as np


class CustomObsWrapper(gym.ObservationWrapper):
    """Augment the Tetris observation dict with hand-crafted board features.

    Besides passing through the wrapped env's ``board`` and ``piece``
    entries, every observation also carries the falling piece's position
    (``x``, ``y``), its shape padded to a fixed 4x4 grid, and per-column
    features: ``heights``, ``bumps``, ``holes_list`` and ``empty_above``.
    """

    def __init__(self, env):
        super().__init__(env)
        # NOTE(review): assumes the wrapped env exposes ``width``/``height``
        # and a Dict observation space containing "board" and "piece".
        self.observation_space = spaces.Dict({
            "board": env.observation_space["board"],
            "piece": env.observation_space["piece"],
            # NOTE(review): low=1 excludes 0; confirm calc_holes_array never
            # emits 0 (e.g. for a column without holes).
            "holes_list": spaces.Box(
                low=1,
                high=env.height,
                shape=(env.width,),
                dtype=np.uint8,
            ),
            "x": spaces.Discrete(env.width),
            # y is the vertical coordinate, so it ranges over rows; bounding
            # it by env.width made the space too small whenever
            # height > width (e.g. the standard 10x20 board).
            "y": spaces.Discrete(env.height),
            "piece_shape": spaces.Box(
                low=0,
                high=1,
                shape=(4, 4),
                dtype=np.uint8,
            ),
            "empty_above": spaces.Box(
                low=0,
                high=env.height,
                shape=(env.width,),
                dtype=np.uint8,
            ),
            "heights": spaces.Box(
                low=0,
                high=env.height,
                shape=(env.width,),
                dtype=np.uint8,
            ),
            # NOTE(review): low=0 presumes get_bumps_from_heights returns
            # absolute height differences; verify it never yields negatives.
            "bumps": spaces.Box(
                low=0,
                high=env.height,
                shape=(env.width - 1,),
                dtype=int,
            ),
        })

    def observation(self, obs):
        """Build the augmented observation from a raw env observation.

        Args:
            obs: Observation dict from the wrapped env; only its "board"
                and "piece" entries are read.

        Returns:
            A dict with the keys declared in ``self.observation_space``.
        """
        board = obs["board"]
        piece = obs["piece"]

        heights = get_heights(board)
        bumps = get_bumps_from_heights(heights)
        # The wrapper itself is passed as the first argument (presumably so
        # the helper can read env dimensions -- verify against agent.utils).
        holes_array = calc_holes_array(self, board, heights)
        # How far each column's top sits below the tallest column. The
        # subtraction already creates a new array, so no explicit copy.
        empty_above = np.max(heights) - heights

        # Pad the current piece matrix into a fixed 4x4 grid so this entry
        # has a constant shape. ``self.piece`` / ``self.current_pos`` are not
        # set in this class; they are presumably forwarded from the wrapped
        # env via gym.Wrapper attribute lookup.
        piece_shape = np.zeros((4, 4), dtype=np.uint8)
        piece_shape[:len(self.piece), :len(self.piece[0])] = self.piece

        return {
            "board": board,
            "x": self.current_pos["x"],
            "y": self.current_pos["y"],
            "piece_shape": piece_shape,
            "piece": piece,
            "empty_above": empty_above,
            "holes_list": holes_array,
            "heights": heights,
            "bumps": bumps,
        }