Upload with huggingface_hub
Browse files- agent.py +13 -0
- my_model.zip +3 -0
- observation_wrapper.py +74 -0
- reward_wrapper.py +66 -0
- utils.py +8 -0
agent.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from stable_baselines3 import A2C
|
2 |
+
from agent.observation_wrapper import CustomObsWrapper
|
3 |
+
|
4 |
+
class Agent:
|
5 |
+
|
6 |
+
def __init__(self, env) -> None:
|
7 |
+
self.model = A2C.load("agent/my_model")
|
8 |
+
self.observation_wrapper = CustomObsWrapper(env)
|
9 |
+
|
10 |
+
def act(self, observation):
|
11 |
+
extended_obsetvation = self.observation_wrapper.observation(observation)
|
12 |
+
|
13 |
+
return self.model.predict(extended_obsetvation, deterministic=True)
|
my_model.zip
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d2f023b0292ff0d225d43e005826d45ce4e0f24ef202bbc1ba08e6f1960ffcc8
|
3 |
+
size 2400942
|
observation_wrapper.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gym
|
2 |
+
from gym import spaces
|
3 |
+
from tetris_gym.utils.board_utils import get_heights, get_bumps_from_heights
|
4 |
+
from agent.utils import calc_holes_array
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
|
9 |
+
class CustomObsWrapper(gym.ObservationWrapper):
|
10 |
+
def __init__(self, env):
|
11 |
+
super().__init__(env)
|
12 |
+
self.observation_space = spaces.Dict({
|
13 |
+
"board": env.observation_space["board"],
|
14 |
+
"piece": env.observation_space["piece"],
|
15 |
+
"holes_list": spaces.Box(
|
16 |
+
low=1,
|
17 |
+
high=env.height,
|
18 |
+
shape=(env.width,),
|
19 |
+
dtype=np.uint8,
|
20 |
+
),
|
21 |
+
"x": spaces.Discrete(env.width),
|
22 |
+
"y": spaces.Discrete(env.width),
|
23 |
+
"piece_shape": spaces.Box(
|
24 |
+
low=0,
|
25 |
+
high=1,
|
26 |
+
shape=(4, 4),
|
27 |
+
dtype=np.uint8,
|
28 |
+
),
|
29 |
+
"empty_above": spaces.Box(
|
30 |
+
low=0,
|
31 |
+
high=env.height,
|
32 |
+
shape=(env.width,),
|
33 |
+
dtype=np.uint8,
|
34 |
+
),
|
35 |
+
"heights": spaces.Box(
|
36 |
+
low=0,
|
37 |
+
high=env.height,
|
38 |
+
shape=(env.width,),
|
39 |
+
dtype=np.uint8,
|
40 |
+
),
|
41 |
+
"bumps": spaces.Box(
|
42 |
+
low=0,
|
43 |
+
high=env.height,
|
44 |
+
shape=(env.width - 1,),
|
45 |
+
dtype=int,
|
46 |
+
)
|
47 |
+
})
|
48 |
+
|
49 |
+
def observation(self, obs):
|
50 |
+
board = obs["board"]
|
51 |
+
piece = obs["piece"]
|
52 |
+
|
53 |
+
heights = get_heights(board)
|
54 |
+
bumps = get_bumps_from_heights(heights)
|
55 |
+
holes_array = calc_holes_array(self, board, heights)
|
56 |
+
empty_above = np.max(heights) - heights[:]
|
57 |
+
piece_shape = np.zeros((4, 4), dtype=np.uint8)
|
58 |
+
piece_shape[:len(self.piece), :len(self.piece[0])] = self.piece[:]
|
59 |
+
|
60 |
+
|
61 |
+
|
62 |
+
obs = {
|
63 |
+
"board": board,
|
64 |
+
"x": self.current_pos["x"],
|
65 |
+
"y": self.current_pos["y"],
|
66 |
+
"piece_shape": piece_shape,
|
67 |
+
"piece": piece,
|
68 |
+
"empty_above": empty_above,
|
69 |
+
"holes_list": holes_array,
|
70 |
+
"heights": heights,
|
71 |
+
"bumps": bumps
|
72 |
+
}
|
73 |
+
|
74 |
+
return obs
|
reward_wrapper.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from math import sqrt
|
2 |
+
|
3 |
+
import gym
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
class CustomRewardWrapper(gym.Wrapper):
|
8 |
+
|
9 |
+
def __init__(self, env):
|
10 |
+
super().__init__(env)
|
11 |
+
self.prev_max_height = 0
|
12 |
+
self.prev_cleared = 0
|
13 |
+
self.prev_score = 0
|
14 |
+
self.prev_holes = 0
|
15 |
+
|
16 |
+
def step(self, action):
|
17 |
+
obs, reward, done, info = self.env.step(action)
|
18 |
+
board = obs["board"]
|
19 |
+
heights = obs["heights"]
|
20 |
+
|
21 |
+
# # Default reward
|
22 |
+
reward = 2
|
23 |
+
# # reward = ((self.height - max(heights)) / self.height)
|
24 |
+
# # reward += np.sum(board)
|
25 |
+
# reward = (self.height - max(heights)) / self.height
|
26 |
+
# reward += 2
|
27 |
+
# #
|
28 |
+
# # reward = (self.score - self.prev_score) + 1
|
29 |
+
# # self.prev_score = self.score
|
30 |
+
# #
|
31 |
+
# # # if max(heights) < self.prev_max_height:
|
32 |
+
# # reward += (self.prev_max_height - max(heights))
|
33 |
+
# # self.prev_max_height = max(heights)
|
34 |
+
# #
|
35 |
+
# reward += self.cleared_lines
|
36 |
+
reward += (self.cleared_lines - self.prev_cleared) ** 3
|
37 |
+
#
|
38 |
+
# # Penalty for big differences between columns
|
39 |
+
reward -= self.get_bumpiness_and_height(board)[0] / self.height
|
40 |
+
#
|
41 |
+
# # Penalty for holes
|
42 |
+
# # holes_val = 0
|
43 |
+
# # for col_num in range(self.width):
|
44 |
+
# # col_value = 0
|
45 |
+
# # for row_num in range(self.height - 1, self.height - 1 - heights[col_num], -1):
|
46 |
+
# # col_value += 1 if board[row_num][col_num] == 1 else -(row_num / self.width)
|
47 |
+
# # holes_val += col_value / (1 + heights[col_num])
|
48 |
+
#
|
49 |
+
holes = self.get_holes(board)
|
50 |
+
reward -= (holes - self.prev_holes) * 0.8
|
51 |
+
|
52 |
+
|
53 |
+
# reward = 1 + ((self.cleared_lines - self.prev_cleared) ** 2) * self.width
|
54 |
+
|
55 |
+
self.prev_max_height = np.max(heights)
|
56 |
+
self.prev_cleared = self.cleared_lines
|
57 |
+
self.prev_score = self.score
|
58 |
+
self.prev_holes = holes
|
59 |
+
|
60 |
+
if self.gameover:
|
61 |
+
self.prev_max_height = 0
|
62 |
+
self.prev_cleared = 0
|
63 |
+
self.prev_score = 0
|
64 |
+
self.prev_holes = 0
|
65 |
+
|
66 |
+
return obs, reward, done, info
|
utils.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def calc_holes_array(self, board, heights):
|
2 |
+
holes_list = []
|
3 |
+
for col_num in range(self.width):
|
4 |
+
col_value = 0
|
5 |
+
for row_num in range(self.height - 1, self.height - 1 - heights[col_num], -1):
|
6 |
+
col_value += 0 if board[row_num][col_num] == 1 else 1
|
7 |
+
holes_list.append(col_value)
|
8 |
+
return holes_list
|