ShiyuHuang
commited on
Commit
•
2322e9b
1
Parent(s):
7da2e8e
Upload folder using huggingface_hub
Browse files- goal_keeper.py +1001 -0
- openrl_policy.py +446 -0
- openrl_utils.py +421 -0
- submission.py +81 -0
goal_keeper.py
ADDED
@@ -0,0 +1,1001 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Copyright 2023 The OpenRL Authors.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
# original code from https://github.com/Sarvar-Anvarov/Google-Research-Football/blob/main/gfootball.py
|
18 |
+
# modified by TARTRL team
|
19 |
+
|
20 |
+
import math
|
21 |
+
import random
|
22 |
+
import numpy as np
|
23 |
+
|
24 |
+
from functools import wraps
|
25 |
+
from enum import Enum
|
26 |
+
from typing import *
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
class Action(Enum):
|
31 |
+
Idle = 0
|
32 |
+
Left = 1
|
33 |
+
TopLeft = 2
|
34 |
+
Top = 3
|
35 |
+
TopRight = 4
|
36 |
+
Right = 5
|
37 |
+
BottomRight = 6
|
38 |
+
Bottom = 7
|
39 |
+
BottomLeft = 8
|
40 |
+
LongPass= 9
|
41 |
+
HighPass = 10
|
42 |
+
ShortPass = 11
|
43 |
+
Shot = 12
|
44 |
+
Sprint = 13
|
45 |
+
ReleaseDirection = 14
|
46 |
+
ReleaseSprint = 15
|
47 |
+
Slide = 16
|
48 |
+
Dribble = 17
|
49 |
+
ReleaseDribble = 18
|
50 |
+
|
51 |
+
|
52 |
+
ALL_DIRECTION_ACTIONS = [Action.Left, Action.TopLeft, Action.Top, Action.TopRight, Action.Right, Action.BottomRight, Action.Bottom, Action.BottomLeft]
|
53 |
+
ALL_DIRECTION_VECS = [(-1, 0), (-1, -1), (0, -1), (1, -1), (1, 0), (1, 1), (0, 1), (-1, 1)]
|
54 |
+
|
55 |
+
sticky_index_to_action = [
|
56 |
+
Action.Left,
|
57 |
+
Action.TopLeft,
|
58 |
+
Action.Top,
|
59 |
+
Action.TopRight,
|
60 |
+
Action.Right,
|
61 |
+
Action.BottomRight,
|
62 |
+
Action.Bottom,
|
63 |
+
Action.BottomLeft,
|
64 |
+
Action.Sprint,
|
65 |
+
Action.Dribble
|
66 |
+
]
|
67 |
+
|
68 |
+
GOAL_BIAS = 0.01
|
69 |
+
|
70 |
+
class PlayerRole(Enum):
|
71 |
+
GoalKeeper = 0
|
72 |
+
CenterBack = 1
|
73 |
+
LeftBack = 2
|
74 |
+
RightBack = 3
|
75 |
+
DefenceMidfield = 4
|
76 |
+
CentralMidfield = 5
|
77 |
+
LeftMidfield = 6
|
78 |
+
RIghtMidfield = 7
|
79 |
+
AttackMidfield = 8
|
80 |
+
CentralFront = 9
|
81 |
+
|
82 |
+
|
83 |
+
class GameMode(Enum):
|
84 |
+
Normal = 0
|
85 |
+
KickOff = 1
|
86 |
+
GoalKick = 2
|
87 |
+
FreeKick = 3
|
88 |
+
Corner = 4
|
89 |
+
ThrowIn = 5
|
90 |
+
Penalty = 6
|
91 |
+
|
92 |
+
|
93 |
+
def human_readable_agent(agent: Callable[[Dict], Action]):
|
94 |
+
"""
|
95 |
+
Decorator allowing for more human-friendly implementation of the agent function.
|
96 |
+
@human_readable_agent
|
97 |
+
def my_agent(obs):
|
98 |
+
...
|
99 |
+
return football_action_set.action_right
|
100 |
+
"""
|
101 |
+
@wraps(agent)
|
102 |
+
def agent_wrapper(obs) -> List[int]:
|
103 |
+
# Extract observations for the first (and only) player we control.
|
104 |
+
# obs = obs['players_raw'][0]
|
105 |
+
# Turn 'sticky_actions' into a set of active actions (strongly typed).
|
106 |
+
obs['sticky_actions'] = { sticky_index_to_action[nr] for nr, action in enumerate(obs['sticky_actions']) if action }
|
107 |
+
# Turn 'game_mode' into an enum.
|
108 |
+
obs['game_mode'] = GameMode(obs['game_mode'])
|
109 |
+
# In case of single agent mode, 'designated' is always equal to 'active'.
|
110 |
+
if 'designated' in obs:
|
111 |
+
del obs['designated']
|
112 |
+
# Conver players' roles to enum.
|
113 |
+
obs['left_team_roles'] = [ PlayerRole(role) for role in obs['left_team_roles'] ]
|
114 |
+
obs['right_team_roles'] = [ PlayerRole(role) for role in obs['right_team_roles'] ]
|
115 |
+
|
116 |
+
action = agent(obs)
|
117 |
+
return [action.value]
|
118 |
+
|
119 |
+
return agent_wrapper
|
120 |
+
|
121 |
+
def find_patterns(obs, player_x, player_y):
|
122 |
+
""" find list of appropriate patterns in groups of memory patterns """
|
123 |
+
for get_group in groups_of_memory_patterns:
|
124 |
+
group = get_group(obs, player_x, player_y)
|
125 |
+
if group["environment_fits"](obs, player_x, player_y):
|
126 |
+
return group["get_memory_patterns"](obs, player_x, player_y)
|
127 |
+
|
128 |
+
|
129 |
+
def get_action_of_agent(obs, player_x, player_y):
|
130 |
+
""" get action of appropriate pattern in agent's memory """
|
131 |
+
memory_patterns = find_patterns(obs, player_x, player_y)
|
132 |
+
# find appropriate pattern in list of memory patterns
|
133 |
+
for get_pattern in memory_patterns:
|
134 |
+
pattern = get_pattern(obs, player_x, player_y)
|
135 |
+
if pattern["environment_fits"](obs, player_x, player_y):
|
136 |
+
return pattern["get_action"](obs, player_x, player_y)
|
137 |
+
|
138 |
+
|
139 |
+
def get_distance(x1, y1, right_team):
|
140 |
+
""" get two-dimensional Euclidean distance, considering y size of the field """
|
141 |
+
return math.sqrt((x1 - right_team[0]) ** 2 + (y1 * 2.38 - right_team[1] * 2.38) ** 2)
|
142 |
+
|
143 |
+
|
144 |
+
def run_to_ball_bottom(obs, player_x, player_y):
|
145 |
+
""" run to the ball if it is to the bottom from player's position """
|
146 |
+
def environment_fits(obs, player_x, player_y):
|
147 |
+
""" environment fits constraints """
|
148 |
+
# ball is to the bottom from player's position
|
149 |
+
if (obs["ball"][1] > player_y and
|
150 |
+
abs(obs["ball"][0] - player_x) < 0.01):
|
151 |
+
return True
|
152 |
+
return False
|
153 |
+
|
154 |
+
def get_action(obs, player_x, player_y):
|
155 |
+
""" get action of this memory pattern """
|
156 |
+
return Action.Bottom
|
157 |
+
|
158 |
+
return {"environment_fits": environment_fits, "get_action": get_action}
|
159 |
+
|
160 |
+
|
161 |
+
def run_to_ball_bottom_left(obs, player_x, player_y):
|
162 |
+
""" run to the ball if it is to the bottom left from player's position """
|
163 |
+
def environment_fits(obs, player_x, player_y):
|
164 |
+
""" environment fits constraints """
|
165 |
+
# ball is to the bottom left from player's position
|
166 |
+
if (obs["ball"][0] < player_x and
|
167 |
+
obs["ball"][1] > player_y):
|
168 |
+
return True
|
169 |
+
return False
|
170 |
+
|
171 |
+
def get_action(obs, player_x, player_y):
|
172 |
+
""" get action of this memory pattern """
|
173 |
+
return Action.BottomLeft
|
174 |
+
|
175 |
+
return {"environment_fits": environment_fits, "get_action": get_action}
|
176 |
+
|
177 |
+
|
178 |
+
def run_to_ball_bottom_right(obs, player_x, player_y):
|
179 |
+
""" run to the ball if it is to the bottom right from player's position """
|
180 |
+
def environment_fits(obs, player_x, player_y):
|
181 |
+
""" environment fits constraints """
|
182 |
+
# ball is to the bottom right from player's position
|
183 |
+
if (obs["ball"][0] > player_x and
|
184 |
+
obs["ball"][1] > player_y):
|
185 |
+
return True
|
186 |
+
return False
|
187 |
+
|
188 |
+
def get_action(obs, player_x, player_y):
|
189 |
+
""" get action of this memory pattern """
|
190 |
+
return Action.BottomRight
|
191 |
+
|
192 |
+
return {"environment_fits": environment_fits, "get_action": get_action}
|
193 |
+
|
194 |
+
|
195 |
+
def run_to_ball_left(obs, player_x, player_y):
|
196 |
+
""" run to the ball if it is to the left from player's position """
|
197 |
+
def environment_fits(obs, player_x, player_y):
|
198 |
+
""" environment fits constraints """
|
199 |
+
# ball is to the left from player's position
|
200 |
+
if (obs["ball"][0] < player_x and
|
201 |
+
abs(obs["ball"][1] - player_y) < 0.01):
|
202 |
+
return True
|
203 |
+
return False
|
204 |
+
|
205 |
+
def get_action(obs, player_x, player_y):
|
206 |
+
""" get action of this memory pattern """
|
207 |
+
return Action.Left
|
208 |
+
|
209 |
+
return {"environment_fits": environment_fits, "get_action": get_action}
|
210 |
+
|
211 |
+
|
212 |
+
def run_to_ball_right(obs, player_x, player_y):
|
213 |
+
""" run to the ball if it is to the right from player's position """
|
214 |
+
def environment_fits(obs, player_x, player_y):
|
215 |
+
""" environment fits constraints """
|
216 |
+
# ball is to the right from player's position
|
217 |
+
if (obs["ball"][0] > player_x and
|
218 |
+
abs(obs["ball"][1] - player_y) < 0.01):
|
219 |
+
return True
|
220 |
+
return False
|
221 |
+
|
222 |
+
def get_action(obs, player_x, player_y):
|
223 |
+
""" get action of this memory pattern """
|
224 |
+
return Action.Right
|
225 |
+
|
226 |
+
return {"environment_fits": environment_fits, "get_action": get_action}
|
227 |
+
|
228 |
+
|
229 |
+
def run_to_ball_top(obs, player_x, player_y):
|
230 |
+
""" run to the ball if it is to the top from player's position """
|
231 |
+
def environment_fits(obs, player_x, player_y):
|
232 |
+
""" environment fits constraints """
|
233 |
+
# ball is to the top from player's position
|
234 |
+
if (obs["ball"][1] < player_y and
|
235 |
+
abs(obs["ball"][0] - player_x) < 0.01):
|
236 |
+
return True
|
237 |
+
return False
|
238 |
+
|
239 |
+
def get_action(obs, player_x, player_y):
|
240 |
+
""" get action of this memory pattern """
|
241 |
+
return Action.Top
|
242 |
+
|
243 |
+
return {"environment_fits": environment_fits, "get_action": get_action}
|
244 |
+
|
245 |
+
|
246 |
+
def run_to_ball_top_left(obs, player_x, player_y):
|
247 |
+
""" run to the ball if it is to the top left from player's position """
|
248 |
+
def environment_fits(obs, player_x, player_y):
|
249 |
+
""" environment fits constraints """
|
250 |
+
# ball is to the top left from player's position
|
251 |
+
if (obs["ball"][0] < player_x and
|
252 |
+
obs["ball"][1] < player_y):
|
253 |
+
return True
|
254 |
+
return False
|
255 |
+
|
256 |
+
def get_action(obs, player_x, player_y):
|
257 |
+
""" get action of this memory pattern """
|
258 |
+
return Action.TopLeft
|
259 |
+
|
260 |
+
return {"environment_fits": environment_fits, "get_action": get_action}
|
261 |
+
|
262 |
+
|
263 |
+
def run_to_ball_top_right(obs, player_x, player_y):
|
264 |
+
""" run to the ball if it is to the top right from player's position """
|
265 |
+
def environment_fits(obs, player_x, player_y):
|
266 |
+
""" environment fits constraints """
|
267 |
+
# ball is to the top right from player's position
|
268 |
+
if (obs["ball"][0] > player_x and
|
269 |
+
obs["ball"][1] < player_y):
|
270 |
+
return True
|
271 |
+
return False
|
272 |
+
|
273 |
+
def get_action(obs, player_x, player_y):
|
274 |
+
""" get action of this memory pattern """
|
275 |
+
return Action.TopRight
|
276 |
+
|
277 |
+
return {"environment_fits": environment_fits, "get_action": get_action}
|
278 |
+
|
279 |
+
|
280 |
+
def idle(obs, player_x, player_y):
|
281 |
+
""" do nothing, release all sticky actions """
|
282 |
+
def environment_fits(obs, player_x, player_y):
|
283 |
+
""" environment fits constraints """
|
284 |
+
return True
|
285 |
+
|
286 |
+
def get_action(obs, player_x, player_y):
|
287 |
+
""" get action of this memory pattern """
|
288 |
+
return Action.Idle
|
289 |
+
|
290 |
+
return {"environment_fits": environment_fits, "get_action": get_action}
|
291 |
+
|
292 |
+
|
293 |
+
def start_sprinting(obs, player_x, player_y):
|
294 |
+
""" make sure player is sprinting """
|
295 |
+
def environment_fits(obs, player_x, player_y):
|
296 |
+
""" environment fits constraints """
|
297 |
+
if Action.Sprint not in obs["sticky_actions"]:
|
298 |
+
return True
|
299 |
+
return False
|
300 |
+
|
301 |
+
def get_action(obs, player_x, player_y):
|
302 |
+
""" get action of this memory pattern """
|
303 |
+
if Action.Dribble in obs['sticky_actions']:
|
304 |
+
return Action.ReleaseDribble
|
305 |
+
return Action.Sprint
|
306 |
+
|
307 |
+
return {"environment_fits": environment_fits, "get_action": get_action}
|
308 |
+
|
309 |
+
|
310 |
+
def corner(obs, player_x, player_y):
|
311 |
+
""" perform a shot in corner game mode """
|
312 |
+
def environment_fits(obs, player_x, player_y):
|
313 |
+
""" environment fits constraints """
|
314 |
+
# it is corner game mode
|
315 |
+
if obs['game_mode'] == GameMode.Corner:
|
316 |
+
return True
|
317 |
+
return False
|
318 |
+
|
319 |
+
def get_action(obs, player_x, player_y):
|
320 |
+
""" get action of this memory pattern """
|
321 |
+
if player_y > 0:
|
322 |
+
if Action.TopRight not in obs["sticky_actions"]:
|
323 |
+
return Action.TopRight
|
324 |
+
else:
|
325 |
+
if Action.BottomRight not in obs["sticky_actions"]:
|
326 |
+
return Action.BottomRight
|
327 |
+
return Action.HighPass
|
328 |
+
|
329 |
+
return {"environment_fits": environment_fits, "get_action": get_action}
|
330 |
+
|
331 |
+
|
332 |
+
def free_kick(obs, player_x, player_y):
|
333 |
+
""" perform a high pass or a shot in free kick game mode """
|
334 |
+
def environment_fits(obs, player_x, player_y):
|
335 |
+
""" environment fits constraints """
|
336 |
+
# it is free kick game mode
|
337 |
+
if obs['game_mode'] == GameMode.FreeKick:
|
338 |
+
return True
|
339 |
+
return False
|
340 |
+
|
341 |
+
def get_action(obs, player_x, player_y):
|
342 |
+
""" get action of this memory pattern """
|
343 |
+
# shot if player close to goal
|
344 |
+
if player_x > 0.5:
|
345 |
+
if player_y > 0:
|
346 |
+
if Action.TopRight not in obs["sticky_actions"]:
|
347 |
+
return Action.TopRight
|
348 |
+
else:
|
349 |
+
if Action.BottomRight not in obs["sticky_actions"]:
|
350 |
+
return Action.BottomRight
|
351 |
+
return Action.Shot
|
352 |
+
# high pass if player far from goal
|
353 |
+
else:
|
354 |
+
if player_y > 0:
|
355 |
+
if Action.BottomRight not in obs["sticky_actions"]:
|
356 |
+
return Action.BottomRight
|
357 |
+
else:
|
358 |
+
if Action.TopRight not in obs['sticky_actions']:
|
359 |
+
return Action.TopRight
|
360 |
+
return Action.ShortPass
|
361 |
+
|
362 |
+
return {"environment_fits": environment_fits, "get_action": get_action}
|
363 |
+
|
364 |
+
|
365 |
+
def goal_kick(obs, player_x, player_y):
|
366 |
+
""" perform a short pass in goal kick game mode """
|
367 |
+
def environment_fits(obs, player_x, player_y):
|
368 |
+
""" environment fits constraints """
|
369 |
+
# it is goal kick game mode
|
370 |
+
if obs['game_mode'] == GameMode.GoalKick:
|
371 |
+
return True
|
372 |
+
return False
|
373 |
+
|
374 |
+
def get_action(obs, player_x, player_y):
|
375 |
+
""" get action of this memory pattern """
|
376 |
+
if Action.BottomRight not in obs["sticky_actions"]:
|
377 |
+
return Action.BottomRight
|
378 |
+
return Action.ShortPass
|
379 |
+
|
380 |
+
return {"environment_fits": environment_fits, "get_action": get_action}
|
381 |
+
|
382 |
+
|
383 |
+
def kick_off(obs, player_x, player_y):
|
384 |
+
""" perform a short pass in kick off game mode """
|
385 |
+
def environment_fits(obs, player_x, player_y):
|
386 |
+
""" environment fits constraints """
|
387 |
+
# it is kick off game mode
|
388 |
+
if obs['game_mode'] == GameMode.KickOff:
|
389 |
+
return True
|
390 |
+
return False
|
391 |
+
|
392 |
+
def get_action(obs, player_x, player_y):
|
393 |
+
""" get action of this memory pattern """
|
394 |
+
if player_y > 0:
|
395 |
+
if Action.Top not in obs["sticky_actions"]:
|
396 |
+
return Action.Top
|
397 |
+
else:
|
398 |
+
if Action.Bottom not in obs["sticky_actions"]:
|
399 |
+
return Action.Bottom
|
400 |
+
return Action.ShortPass
|
401 |
+
|
402 |
+
return {"environment_fits": environment_fits, "get_action": get_action}
|
403 |
+
|
404 |
+
|
405 |
+
def penalty(obs, player_x, player_y):
|
406 |
+
""" perform a shot in penalty game mode """
|
407 |
+
def environment_fits(obs, player_x, player_y):
|
408 |
+
""" environment fits constraints """
|
409 |
+
# it is penalty game mode
|
410 |
+
if obs['game_mode'] == GameMode.Penalty:
|
411 |
+
return True
|
412 |
+
return False
|
413 |
+
|
414 |
+
def get_action(obs, player_x, player_y):
|
415 |
+
""" get action of this memory pattern """
|
416 |
+
if (random.random() < 0.5 and
|
417 |
+
Action.TopRight not in obs["sticky_actions"] and
|
418 |
+
Action.BottomRight not in obs["sticky_actions"]):
|
419 |
+
return Action.TopRight
|
420 |
+
else:
|
421 |
+
if Action.BottomRight not in obs["sticky_actions"]:
|
422 |
+
return Action.BottomRight
|
423 |
+
return Action.Shot
|
424 |
+
|
425 |
+
return {"environment_fits": environment_fits, "get_action": get_action}
|
426 |
+
|
427 |
+
def throw_in(obs, player_x, player_y):
|
428 |
+
""" perform a short pass in throw in game mode """
|
429 |
+
def environment_fits(obs, player_x, player_y):
|
430 |
+
""" environment fits constraints """
|
431 |
+
# it is throw in game mode
|
432 |
+
if obs['game_mode'] == GameMode.ThrowIn:
|
433 |
+
return True
|
434 |
+
return False
|
435 |
+
|
436 |
+
def get_action(obs, player_x, player_y):
|
437 |
+
""" get action of this memory pattern """
|
438 |
+
if Action.Right not in obs["sticky_actions"]:
|
439 |
+
return Action.Right
|
440 |
+
return Action.ShortPass
|
441 |
+
|
442 |
+
return {"environment_fits": environment_fits, "get_action": get_action}
|
443 |
+
|
444 |
+
|
445 |
+
def defence_memory_patterns(obs, player_x, player_y):
|
446 |
+
""" group of memory patterns for environments in which opponent's team has the ball """
|
447 |
+
def environment_fits(obs, player_x, player_y):
|
448 |
+
""" environment fits constraints """
|
449 |
+
# player don't have the ball
|
450 |
+
if obs["ball_owned_team"] != 0:
|
451 |
+
return True
|
452 |
+
return False
|
453 |
+
|
454 |
+
def get_memory_patterns(obs, player_x, player_y):
|
455 |
+
""" get list of memory patterns """
|
456 |
+
# shift ball position
|
457 |
+
obs["ball"][0] += obs["ball_direction"][0] * 7
|
458 |
+
obs["ball"][1] += obs["ball_direction"][1] * 3
|
459 |
+
# if opponent has the ball and is far from y axis center
|
460 |
+
if abs(obs["ball"][1]) > 0.07 and obs["ball_owned_team"] == 1:
|
461 |
+
obs["ball"][0] -= 0.01
|
462 |
+
if obs["ball"][1] > 0:
|
463 |
+
obs["ball"][1] -= 0.01
|
464 |
+
else:
|
465 |
+
obs["ball"][1] += 0.01
|
466 |
+
|
467 |
+
memory_patterns = [
|
468 |
+
start_sprinting,
|
469 |
+
run_to_ball_right,
|
470 |
+
run_to_ball_left,
|
471 |
+
run_to_ball_bottom,
|
472 |
+
run_to_ball_top,
|
473 |
+
run_to_ball_top_right,
|
474 |
+
run_to_ball_top_left,
|
475 |
+
run_to_ball_bottom_right,
|
476 |
+
run_to_ball_bottom_left,
|
477 |
+
idle
|
478 |
+
]
|
479 |
+
return memory_patterns
|
480 |
+
|
481 |
+
return {"environment_fits": environment_fits, "get_memory_patterns": get_memory_patterns}
|
482 |
+
|
483 |
+
def goalkeeper_memory_patterns(obs, player_x, player_y):
|
484 |
+
""" group of memory patterns for goalkeeper """
|
485 |
+
def environment_fits(obs, player_x, player_y):
|
486 |
+
""" environment fits constraints """
|
487 |
+
# player is a goalkeeper have the ball
|
488 |
+
if (obs["ball_owned_player"] == obs["active"] and
|
489 |
+
obs["ball_owned_team"] == 0 and
|
490 |
+
obs["ball_owned_player"] == 0):
|
491 |
+
return True
|
492 |
+
return False
|
493 |
+
|
494 |
+
def get_memory_patterns(obs, player_x, player_y):
|
495 |
+
""" get list of memory patterns """
|
496 |
+
memory_patterns = [
|
497 |
+
long_pass_forward,
|
498 |
+
idle
|
499 |
+
]
|
500 |
+
return memory_patterns
|
501 |
+
|
502 |
+
return {"environment_fits": environment_fits, "get_memory_patterns": get_memory_patterns}
|
503 |
+
|
504 |
+
|
505 |
+
def offence_memory_patterns(obs, player_x, player_y):
|
506 |
+
""" group of memory patterns for environments in which player's team has the ball """
|
507 |
+
def environment_fits(obs, player_x, player_y):
|
508 |
+
""" environment fits constraints """
|
509 |
+
# player have the ball
|
510 |
+
if obs["ball_owned_player"] == obs["active"] and obs["ball_owned_team"] == 0:
|
511 |
+
return True
|
512 |
+
return False
|
513 |
+
|
514 |
+
def get_memory_patterns(obs, player_x, player_y):
|
515 |
+
""" get list of memory patterns """
|
516 |
+
memory_patterns = [
|
517 |
+
close_to_goalkeeper_shot,
|
518 |
+
spot_shot,
|
519 |
+
cross,
|
520 |
+
long_pass_forward,
|
521 |
+
keep_the_ball,
|
522 |
+
idle
|
523 |
+
]
|
524 |
+
return memory_patterns
|
525 |
+
|
526 |
+
return {"environment_fits": environment_fits, "get_memory_patterns": get_memory_patterns}
|
527 |
+
|
528 |
+
|
529 |
+
def other_memory_patterns(obs, player_x, player_y):
|
530 |
+
""" group of memory patterns for all other environments """
|
531 |
+
def environment_fits(obs, player_x, player_y):
|
532 |
+
""" environment fits constraints """
|
533 |
+
return True
|
534 |
+
|
535 |
+
def get_memory_patterns(obs, player_x, player_y):
|
536 |
+
""" get list of memory patterns """
|
537 |
+
memory_patterns = [
|
538 |
+
idle
|
539 |
+
]
|
540 |
+
return memory_patterns
|
541 |
+
|
542 |
+
return {"environment_fits": environment_fits, "get_memory_patterns": get_memory_patterns}
|
543 |
+
|
544 |
+
def special_game_modes_memory_patterns(obs, player_x, player_y):
|
545 |
+
""" group of memory patterns for special game mode environments """
|
546 |
+
def environment_fits(obs, player_x, player_y):
|
547 |
+
""" environment fits constraints """
|
548 |
+
# if game mode is not normal
|
549 |
+
if obs['game_mode'] != GameMode.Normal:
|
550 |
+
return True
|
551 |
+
return False
|
552 |
+
|
553 |
+
def get_memory_patterns(obs, player_x, player_y):
|
554 |
+
""" get list of memory patterns """
|
555 |
+
memory_patterns = [
|
556 |
+
corner,
|
557 |
+
free_kick,
|
558 |
+
goal_kick,
|
559 |
+
kick_off,
|
560 |
+
penalty,
|
561 |
+
throw_in,
|
562 |
+
idle
|
563 |
+
]
|
564 |
+
return memory_patterns
|
565 |
+
|
566 |
+
return {"environment_fits": environment_fits, "get_memory_patterns": get_memory_patterns}
|
567 |
+
|
568 |
+
|
569 |
+
def special_spot_shot(obs, player_x, player_y):
|
570 |
+
""" group of memory patterns for special game mode environments """
|
571 |
+
def environment_fits(obs, player_x, player_y):
|
572 |
+
""" environment fits constraints """
|
573 |
+
# if game mode is not normal
|
574 |
+
if player_x > 0.8 and abs(player_y) < 0.21:
|
575 |
+
return True
|
576 |
+
return False
|
577 |
+
|
578 |
+
def get_memory_patterns(obs, player_x, player_y):
|
579 |
+
""" get list of memory patterns """
|
580 |
+
memory_patterns = [
|
581 |
+
shot,
|
582 |
+
idle
|
583 |
+
]
|
584 |
+
return memory_patterns
|
585 |
+
|
586 |
+
return {"environment_fits": environment_fits, "get_memory_patterns": get_memory_patterns}
|
587 |
+
|
588 |
+
|
589 |
+
def own_goal(obs, player_x, player_y):
|
590 |
+
""" group of memory patterns for special game mode environments """
|
591 |
+
def environment_fits(obs, player_x, player_y):
|
592 |
+
""" environment fits constraints """
|
593 |
+
# if game mode is not normal
|
594 |
+
if player_x < -0.9 and player_y:
|
595 |
+
return True
|
596 |
+
return False
|
597 |
+
|
598 |
+
def get_memory_patterns(obs, player_x, player_y):
|
599 |
+
""" get list of memory patterns """
|
600 |
+
memory_patterns = [
|
601 |
+
own_goal_2
|
602 |
+
]
|
603 |
+
return memory_patterns
|
604 |
+
|
605 |
+
return {"environment_fits": environment_fits, "get_memory_patterns": get_memory_patterns}
|
606 |
+
|
607 |
+
def get_best_direction(obs, target_direction):
|
608 |
+
active_position = obs["left_team"][obs["active"]]
|
609 |
+
relative_goal_position = np.array(target_direction) - active_position
|
610 |
+
all_directions_vecs = [np.array(v) / np.linalg.norm(np.array(v)) for v in ALL_DIRECTION_VECS]
|
611 |
+
best_direction = np.argmax([np.dot(relative_goal_position, v) for v in all_directions_vecs])
|
612 |
+
return ALL_DIRECTION_ACTIONS[best_direction]
|
613 |
+
|
614 |
+
def get_distance2ball(obs):
|
615 |
+
return np.linalg.norm(obs["ball"][:2] - obs["left_team"][obs['active']])
|
616 |
+
|
617 |
+
def get_target2line(obs):
|
618 |
+
active_position = obs["left_team"][obs["active"]]
|
619 |
+
ball_x, ball_y = obs['ball'][0], obs['ball'][1]
|
620 |
+
distance2goal = ((ball_x + 1) ** 2 + ball_y ** 2) ** 0.5 + 1e-5
|
621 |
+
cos_theta = (ball_x + 1) / distance2goal
|
622 |
+
sin_theta = ball_y / distance2goal
|
623 |
+
target_pos = np.array([0.03 * cos_theta - 1, 0.03 * sin_theta])
|
624 |
+
return target_pos
|
625 |
+
|
626 |
+
def already_near_goal(obs, player_x, player_y):
|
627 |
+
""" do nothing, release all sticky actions """
|
628 |
+
def environment_fits(obs, player_x, player_y):
|
629 |
+
""" environment fits constraints """
|
630 |
+
active_position = obs["left_team"][obs["active"]]
|
631 |
+
relative_goal_position = np.array([-1 + GOAL_BIAS, 0]) - active_position
|
632 |
+
distance2goal = np.linalg.norm(relative_goal_position)
|
633 |
+
if distance2goal < 0.02:
|
634 |
+
return True
|
635 |
+
return False
|
636 |
+
|
637 |
+
def get_action(obs, player_x, player_y):
|
638 |
+
""" get action of this memory pattern """
|
639 |
+
# print(obs["sticky_actions"])
|
640 |
+
if Action.Sprint in obs["sticky_actions"]:
|
641 |
+
return Action.ReleaseSprint
|
642 |
+
if Action.Dribble in obs["sticky_actions"]:
|
643 |
+
return Action.ReleaseDribble
|
644 |
+
if len(obs["sticky_actions"]) > 0:
|
645 |
+
return Action.ReleaseDirection
|
646 |
+
return Action.Idle
|
647 |
+
|
648 |
+
return {"environment_fits": environment_fits, "get_action": get_action}
|
649 |
+
|
650 |
+
def already_in_line(obs, player_x, player_y):
|
651 |
+
""" do nothing, release all sticky actions """
|
652 |
+
def environment_fits(obs, player_x, player_y):
|
653 |
+
""" environment fits constraints """
|
654 |
+
|
655 |
+
target_pos = get_target2line(obs)
|
656 |
+
distance2goal = np.linalg.norm(target_pos - obs['left_team'][obs['active']])
|
657 |
+
if distance2goal < 0.02:
|
658 |
+
return True
|
659 |
+
return False
|
660 |
+
|
661 |
+
def get_action(obs, player_x, player_y):
|
662 |
+
""" get action of this memory pattern """
|
663 |
+
# print(obs["sticky_actions"])
|
664 |
+
if Action.Sprint in obs["sticky_actions"]:
|
665 |
+
return Action.ReleaseSprint
|
666 |
+
if Action.Dribble in obs["sticky_actions"]:
|
667 |
+
return Action.ReleaseDribble
|
668 |
+
if len(obs["sticky_actions"]) > 0:
|
669 |
+
return Action.ReleaseDirection
|
670 |
+
return Action.Idle
|
671 |
+
|
672 |
+
return {"environment_fits": environment_fits, "get_action": get_action}
|
673 |
+
|
674 |
+
def run_to_goal(obs, player_x, player_y):
|
675 |
+
def environment_fits(obs, player_x, player_y):
|
676 |
+
""" environment fits constraints """
|
677 |
+
return True
|
678 |
+
|
679 |
+
def get_action(obs, player_x, player_y):
|
680 |
+
# active_position = obs["left_team"][obs["active"]]
|
681 |
+
# relative_goal_position = np.array([-1 + GOAL_BIAS, 0]) - active_position
|
682 |
+
# all_directions_vecs = [np.array(v) / np.linalg.norm(np.array(v)) for v in ALL_DIRECTION_VECS]
|
683 |
+
# best_direction = np.argmax([np.dot(relative_goal_position, v) for v in all_directions_vecs])
|
684 |
+
# return ALL_DIRECTION_ACTIONS[best_direction]
|
685 |
+
return get_best_direction(obs, [-1 + GOAL_BIAS, 0])
|
686 |
+
|
687 |
+
return {"environment_fits": environment_fits, "get_action": get_action}
|
688 |
+
|
689 |
+
def run_to_line(obs, player_x, player_y):
|
690 |
+
def environment_fits(obs, player_x, player_y):
|
691 |
+
""" environment fits constraints """
|
692 |
+
return True
|
693 |
+
|
694 |
+
def get_action(obs, player_x, player_y):
|
695 |
+
target_pos = get_target2line(obs)
|
696 |
+
return get_best_direction(obs, target_pos)
|
697 |
+
|
698 |
+
return {"environment_fits": environment_fits, "get_action": get_action}
|
699 |
+
|
700 |
+
def goal_keeper_far_pattern(obs, player_x, player_y):
|
701 |
+
def environment_fits(obs, player_x, player_y):
|
702 |
+
""" environment fits constraints """
|
703 |
+
# player have the ball
|
704 |
+
if (obs["active"] == 0):
|
705 |
+
active_position = obs["left_team"][0]
|
706 |
+
relative_ball_position = obs["ball"][:2] - active_position
|
707 |
+
distance2ball = np.linalg.norm(relative_ball_position)
|
708 |
+
if distance2ball > 0.5 or (obs['ball_owned_team'] == 0 and obs['ball_owned_player'] != 0):
|
709 |
+
return True
|
710 |
+
if active_position[0] > -0.7 or abs(active_position[1]) > 0.25:
|
711 |
+
for teammate_pos in obs['left_team'][1:]:
|
712 |
+
teammate_dis = np.linalg.norm(obs["ball"][:2] - teammate_pos)
|
713 |
+
if teammate_dis < distance2ball:
|
714 |
+
return True
|
715 |
+
return False
|
716 |
+
|
717 |
+
def get_memory_patterns(obs, player_x, player_y):
|
718 |
+
""" get list of memory patterns """
|
719 |
+
memory_patterns = [
|
720 |
+
already_near_goal,
|
721 |
+
start_sprinting,
|
722 |
+
run_to_goal
|
723 |
+
]
|
724 |
+
return memory_patterns
|
725 |
+
|
726 |
+
return {"environment_fits": environment_fits, "get_memory_patterns": get_memory_patterns}
|
727 |
+
|
728 |
+
def ball_distance_2_5(obs, player_x, player_y):
|
729 |
+
def environment_fits(obs, player_x, player_y):
|
730 |
+
""" environment fits constraints """
|
731 |
+
# player have the ball
|
732 |
+
if (obs["active"] == 0 and obs['ball_owned_team'] != 0):
|
733 |
+
distance2ball = get_distance2ball(obs)
|
734 |
+
if distance2ball <= 0.5 and distance2ball >= 0.2:
|
735 |
+
return True
|
736 |
+
return False
|
737 |
+
|
738 |
+
def get_memory_patterns(obs, player_x, player_y):
|
739 |
+
""" get list of memory patterns """
|
740 |
+
memory_patterns = [
|
741 |
+
already_in_line,
|
742 |
+
start_sprinting,
|
743 |
+
run_to_line
|
744 |
+
]
|
745 |
+
return memory_patterns
|
746 |
+
|
747 |
+
return {"environment_fits": environment_fits, "get_memory_patterns": get_memory_patterns}
|
748 |
+
|
749 |
+
def ball_distance_close(obs, player_x, player_y):
|
750 |
+
def environment_fits(obs, player_x, player_y):
|
751 |
+
""" environment fits constraints """
|
752 |
+
# player have the ball
|
753 |
+
if (obs["active"] == 0 and obs['ball_owned_team'] != 0):
|
754 |
+
distance2ball = get_distance2ball(obs)
|
755 |
+
if distance2ball < 0.25:
|
756 |
+
return True
|
757 |
+
return False
|
758 |
+
|
759 |
+
def get_memory_patterns(obs, player_x, player_y):
|
760 |
+
""" get list of memory patterns """
|
761 |
+
memory_patterns = [
|
762 |
+
shot
|
763 |
+
]
|
764 |
+
return memory_patterns
|
765 |
+
|
766 |
+
return {"environment_fits": environment_fits, "get_memory_patterns": get_memory_patterns}
|
767 |
+
|
768 |
+
# list of groups of memory patterns
|
769 |
+
groups_of_memory_patterns = [
|
770 |
+
goal_keeper_far_pattern, # 安全
|
771 |
+
goalkeeper_memory_patterns, # 守门员持球
|
772 |
+
# special_spot_shot, # 射门 进不去
|
773 |
+
special_game_modes_memory_patterns, # 特殊game mode
|
774 |
+
ball_distance_2_5,
|
775 |
+
ball_distance_close,
|
776 |
+
# own_goal,
|
777 |
+
# offence_memory_patterns, # 我方持球 进不去
|
778 |
+
defence_memory_patterns,
|
779 |
+
other_memory_patterns # idle
|
780 |
+
]
|
781 |
+
|
782 |
+
|
783 |
+
def keep_the_ball(obs, player_x, player_y):
|
784 |
+
def environment_fits(obs, player_x, player_y):
|
785 |
+
return True
|
786 |
+
|
787 |
+
def get_action(obs, player_x, player_y):
|
788 |
+
right_team, left_team = obs['right_team'], obs['left_team']
|
789 |
+
dist = [get_distance(player_x, player_y, i) for i in right_team]
|
790 |
+
closest = right_team[np.argmin(dist)]
|
791 |
+
near = [i for i in right_team if (i[0] < player_x + 0.2) and (i[0] > player_x) and (i[1] > player_y - 0.05)
|
792 |
+
and (i[1] < player_y + 0.05)]
|
793 |
+
back = [i for i in right_team if (i[0] > player_x)]
|
794 |
+
bottom_right = [i for i in left_team if (i[0] > player_x - 0.05) and (i[0] < player_x + 0.2) and (i[1] < player_y + 0.2) and
|
795 |
+
(i[1] > player_y)]
|
796 |
+
top_right = [i for i in left_team if (i[0] > player_x - 0.05) and (i[0] < player_x + 0.2) and (i[1] > player_y - 0.2) and
|
797 |
+
(i[1] < player_y)]
|
798 |
+
bottom_left = [i for i in left_team if (i[0] < player_x) and (i[0] > player_x - 0.2) and (i[1] < player_y + 0.2) and
|
799 |
+
(i[1] > player_y)]
|
800 |
+
top_left = [i for i in left_team if (i[0] < player_x) and (i[0] > player_x - 0.2) and (i[1] > player_y - 0.2) and
|
801 |
+
(i[1] < player_y)]
|
802 |
+
|
803 |
+
|
804 |
+
if len(near) == 0:
|
805 |
+
return Action.Right
|
806 |
+
|
807 |
+
if player_y > 0:
|
808 |
+
if player_y > 0.35:
|
809 |
+
return Action.Right
|
810 |
+
if len(bottom_right) > 0:
|
811 |
+
if Action.BottomRight not in obs['sticky_actions']:
|
812 |
+
return Action.BottomRight
|
813 |
+
return Action.ShortPass
|
814 |
+
return Action.BottomRight
|
815 |
+
|
816 |
+
if player_y < 0:
|
817 |
+
if player_y < -0.35:
|
818 |
+
return Action.Right
|
819 |
+
if len(top_right) > 0:
|
820 |
+
if Action.TopRight not in obs['sticky_actions']:
|
821 |
+
return Action.TopRight
|
822 |
+
return Action.ShortPass
|
823 |
+
return Action.TopRight
|
824 |
+
|
825 |
+
return {'environment_fits': environment_fits, 'get_action': get_action}
|
826 |
+
|
827 |
+
|
828 |
+
def spot_shot(obs, player_x, player_y):
|
829 |
+
""" shot if close to the goalkeeper """
|
830 |
+
def environment_fits(obs, player_x, player_y):
|
831 |
+
""" environment fits constraints """
|
832 |
+
# shoot if in spotted location
|
833 |
+
if player_x > 0.75 and abs(player_y) < 0.21:
|
834 |
+
return True
|
835 |
+
return False
|
836 |
+
|
837 |
+
|
838 |
+
def get_action(obs, player_x, player_y):
|
839 |
+
""" get action of this memory pattern """
|
840 |
+
if player_y >= 0:
|
841 |
+
if Action.TopRight not in obs["sticky_actions"]:
|
842 |
+
return Action.TopRight
|
843 |
+
else:
|
844 |
+
if Action.BottomRight not in obs["sticky_actions"]:
|
845 |
+
return Action.BottomRight
|
846 |
+
return Action.Shot
|
847 |
+
|
848 |
+
return {"environment_fits": environment_fits, "get_action": get_action}
|
849 |
+
|
850 |
+
|
851 |
+
def cross(obs, player_x, player_y):
|
852 |
+
def environment_fits(obs, player_x, player_y):
|
853 |
+
if player_x > 0.7 and abs(player_y) > 0.21:
|
854 |
+
return True
|
855 |
+
return False
|
856 |
+
|
857 |
+
def get_action(obs, player_x, player_y):
|
858 |
+
|
859 |
+
if player_x > 0.88:
|
860 |
+
if player_y > 0:
|
861 |
+
if Action.Top not in obs['sticky_actions']:
|
862 |
+
return Action.Top
|
863 |
+
else:
|
864 |
+
if Action.Bottom not in obs['sticky_actions']:
|
865 |
+
return Action.Bottom
|
866 |
+
return Action.HighPass
|
867 |
+
|
868 |
+
if player_x > 0.9:
|
869 |
+
if (Action.Right in obs['sticky_actions'] or
|
870 |
+
Action.TopRight in obs['sticky_actions'] or
|
871 |
+
Action.BottomRight in obs['sticky_actions']):
|
872 |
+
return Action.ReleaseDirection
|
873 |
+
if Action.Right not in obs['sticky_actions']:
|
874 |
+
if player_y > 0:
|
875 |
+
if Action.Top not in obs['sticky_actions']:
|
876 |
+
return Action.Top
|
877 |
+
if player_y < 0:
|
878 |
+
if Action.Bottom not in obs['sticky_actions']:
|
879 |
+
return Action.Bottom
|
880 |
+
return Action.HighPass
|
881 |
+
|
882 |
+
return {"environment_fits": environment_fits, "get_action": get_action}
|
883 |
+
|
884 |
+
|
885 |
+
def close_to_goalkeeper_shot(obs, player_x, player_y):
|
886 |
+
""" shot if close to the goalkeeper """
|
887 |
+
def environment_fits(obs, player_x, player_y):
|
888 |
+
""" environment fits constraints """
|
889 |
+
goalkeeper_x = obs["right_team"][0][0] + obs["right_team_direction"][0][0] * 13
|
890 |
+
goalkeeper_y = obs["right_team"][0][1] + obs["right_team_direction"][0][1] * 13
|
891 |
+
goalkeeper = [goalkeeper_x,goalkeeper_y]
|
892 |
+
|
893 |
+
if get_distance(player_x, player_y, goalkeeper) < 0.25:
|
894 |
+
return True
|
895 |
+
return False
|
896 |
+
|
897 |
+
def get_action(obs, player_x, player_y):
|
898 |
+
""" get action of this memory pattern """
|
899 |
+
if player_y >= 0:
|
900 |
+
if Action.TopRight not in obs["sticky_actions"]:
|
901 |
+
return Action.TopRight
|
902 |
+
else:
|
903 |
+
if Action.BottomRight not in obs["sticky_actions"]:
|
904 |
+
return Action.BottomRight
|
905 |
+
return Action.Shot
|
906 |
+
|
907 |
+
return {"environment_fits": environment_fits, "get_action": get_action}
|
908 |
+
|
909 |
+
|
910 |
+
def long_pass_forward(obs, player_x, player_y):
|
911 |
+
""" perform a high pass, if far from opponent's goal """
|
912 |
+
def environment_fits(obs, player_x, player_y):
|
913 |
+
""" environment fits constraints """
|
914 |
+
right_team = obs["right_team"][1:]
|
915 |
+
# player have the ball and is far from opponent's goal
|
916 |
+
if player_x < -0.4:
|
917 |
+
return True
|
918 |
+
return False
|
919 |
+
|
920 |
+
def get_action(obs, player_x, player_y):
|
921 |
+
""" get action of this memory pattern """
|
922 |
+
right_team, left_team = obs['right_team'], obs['left_team']
|
923 |
+
dist = [get_distance(player_x, player_y, i) for i in right_team]
|
924 |
+
closest = right_team[np.argmin(dist)]
|
925 |
+
|
926 |
+
|
927 |
+
if abs(player_y) > 0.22:
|
928 |
+
if Action.Right not in obs["sticky_actions"]:
|
929 |
+
return Action.Right
|
930 |
+
return Action.HighPass
|
931 |
+
|
932 |
+
if np.min(dist) > 0.4:
|
933 |
+
if player_y > 0:
|
934 |
+
return Action.Bottom
|
935 |
+
else:
|
936 |
+
return Action.Top
|
937 |
+
|
938 |
+
if np.min(dist) < 0.4 and np.min(dist) > 0.2:
|
939 |
+
if player_y < 0:
|
940 |
+
return Action.TopRight
|
941 |
+
else:
|
942 |
+
return Action.BottomRight
|
943 |
+
|
944 |
+
if np.min(dist) < 0.2:
|
945 |
+
if Action.Right not in obs['sticky_actions']:
|
946 |
+
return Action.Right
|
947 |
+
return Action.HighPass
|
948 |
+
|
949 |
+
return {"environment_fits": environment_fits, "get_action": get_action}
|
950 |
+
|
951 |
+
def shot(obs, player_x, player_y):
|
952 |
+
def environment_fits(obs, player_x, player_y):
|
953 |
+
return True
|
954 |
+
|
955 |
+
def get_action(obs, player_x, player_y):
|
956 |
+
# if player_y > 0:
|
957 |
+
# if Action.TopRight not in obs['sticky_actions']:
|
958 |
+
# return Action.TopRight
|
959 |
+
# else:
|
960 |
+
# if Action.BottomRight not in obs['sticky_actions']:
|
961 |
+
# return Action.BottomRight
|
962 |
+
return Action.Shot
|
963 |
+
|
964 |
+
return {"environment_fits": environment_fits, "get_action": get_action}
|
965 |
+
|
966 |
+
|
967 |
+
def own_goal_2(obs, player_x, player_y):
|
968 |
+
def environment_fits(obs, player_x, player_y):
|
969 |
+
return True
|
970 |
+
|
971 |
+
def get_action(obs, player_x, player_y):
|
972 |
+
return Action.Shot
|
973 |
+
|
974 |
+
return {"environment_fits": environment_fits, "get_action": get_action}
|
975 |
+
|
976 |
+
|
977 |
+
# @human_readable_agent wrapper modifies raw observations
|
978 |
+
# provided by the environment:
|
979 |
+
# https://github.com/google-research/football/blob/master/gfootball/doc/observation.md#raw-observations
|
980 |
+
# into a form easier to work with by humans.
|
981 |
+
# Following modifications are applied:
|
982 |
+
# - Action, PlayerRole and GameMode enums are introduced.
|
983 |
+
# - 'sticky_actions' are turned into a set of active actions (Action enum)
|
984 |
+
# see usage example below.
|
985 |
+
# - 'game_mode' is turned into GameMode enum.
|
986 |
+
# - 'designated' field is removed, as it always equals to 'active'
|
987 |
+
# when a single player is controlled on the team.
|
988 |
+
# - 'left_team_roles'/'right_team_roles' are turned into PlayerRole enums.
|
989 |
+
# - Action enum is to be returned by the agent function.
|
990 |
+
@human_readable_agent
|
991 |
+
def agent_get_action(obs):
|
992 |
+
""" Ole ole ole ole """
|
993 |
+
# dictionary for Memory Patterns data
|
994 |
+
obs["memory_patterns"] = {}
|
995 |
+
# We always control left team (observations and actions
|
996 |
+
# are mirrored appropriately by the environment).
|
997 |
+
controlled_player_pos = obs["left_team"][obs["active"]]
|
998 |
+
# get action of appropriate pattern in agent's memory
|
999 |
+
action = get_action_of_agent(obs, controlled_player_pos[0], controlled_player_pos[1])
|
1000 |
+
# return action
|
1001 |
+
return action
|
openrl_policy.py
ADDED
@@ -0,0 +1,446 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Copyright 2023 The OpenRL Authors.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
from torch.distributions import Categorical
|
22 |
+
|
23 |
+
import gym
|
24 |
+
|
25 |
+
def check(input):
|
26 |
+
output = torch.from_numpy(input) if type(input) == np.ndarray else input
|
27 |
+
return output
|
28 |
+
|
29 |
+
class FcEncoder(nn.Module):
|
30 |
+
def __init__(self, fc_num, input_size, output_size):
|
31 |
+
super(FcEncoder, self).__init__()
|
32 |
+
self.first_mlp = nn.Sequential(
|
33 |
+
nn.Linear(input_size, output_size), nn.ReLU(), nn.LayerNorm(output_size)
|
34 |
+
)
|
35 |
+
self.mlp = nn.Sequential()
|
36 |
+
for _ in range(fc_num - 1):
|
37 |
+
self.mlp.append(nn.Sequential(
|
38 |
+
nn.Linear(output_size, output_size), nn.ReLU(), nn.LayerNorm(output_size)
|
39 |
+
))
|
40 |
+
|
41 |
+
def forward(self, x):
|
42 |
+
output = self.first_mlp(x)
|
43 |
+
return self.mlp(output)
|
44 |
+
|
45 |
+
def init(module, weight_init, bias_init, gain=1):
|
46 |
+
weight_init(module.weight.data, gain=gain)
|
47 |
+
if module.bias is not None:
|
48 |
+
bias_init(module.bias.data)
|
49 |
+
return module
|
50 |
+
|
51 |
+
|
52 |
+
class FixedCategorical(torch.distributions.Categorical):
|
53 |
+
def sample(self):
|
54 |
+
return super().sample().unsqueeze(-1)
|
55 |
+
|
56 |
+
def log_probs(self, actions):
|
57 |
+
return (
|
58 |
+
super()
|
59 |
+
.log_prob(actions.squeeze(-1))
|
60 |
+
.view(actions.size(0), -1)
|
61 |
+
.sum(-1)
|
62 |
+
.unsqueeze(-1)
|
63 |
+
)
|
64 |
+
|
65 |
+
def mode(self):
|
66 |
+
return self.probs.argmax(dim=-1, keepdim=True)
|
67 |
+
|
68 |
+
class Categorical(nn.Module):
|
69 |
+
def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01):
|
70 |
+
super(Categorical, self).__init__()
|
71 |
+
init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
|
72 |
+
def init_(m):
|
73 |
+
return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain)
|
74 |
+
|
75 |
+
self.linear = init_(nn.Linear(num_inputs, num_outputs))
|
76 |
+
|
77 |
+
def forward(self, x, available_actions=None):
|
78 |
+
x = self.linear(x)
|
79 |
+
if available_actions is not None:
|
80 |
+
x[available_actions == 0] = -1e10
|
81 |
+
return FixedCategorical(logits=x)
|
82 |
+
|
83 |
+
|
84 |
+
class AddBias(nn.Module):
|
85 |
+
def __init__(self, bias):
|
86 |
+
super(AddBias, self).__init__()
|
87 |
+
self._bias = nn.Parameter(bias.unsqueeze(1))
|
88 |
+
|
89 |
+
def forward(self, x):
|
90 |
+
if x.dim() == 2:
|
91 |
+
bias = self._bias.t().view(1, -1)
|
92 |
+
else:
|
93 |
+
bias = self._bias.t().view(1, -1, 1, 1)
|
94 |
+
|
95 |
+
return x + bias
|
96 |
+
|
97 |
+
class ACTLayer(nn.Module):
|
98 |
+
def __init__(self, action_space, inputs_dim, use_orthogonal, gain):
|
99 |
+
super(ACTLayer, self).__init__()
|
100 |
+
self.multidiscrete_action = False
|
101 |
+
self.continuous_action = False
|
102 |
+
self.mixed_action = False
|
103 |
+
|
104 |
+
action_dim = action_space.n
|
105 |
+
self.action_out = Categorical(inputs_dim, action_dim, use_orthogonal, gain)
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
def forward(self, x, available_actions=None, deterministic=False):
|
110 |
+
if self.mixed_action :
|
111 |
+
actions = []
|
112 |
+
action_log_probs = []
|
113 |
+
for action_out in self.action_outs:
|
114 |
+
action_logit = action_out(x)
|
115 |
+
action = action_logit.mode() if deterministic else action_logit.sample()
|
116 |
+
action_log_prob = action_logit.log_probs(action)
|
117 |
+
actions.append(action.float())
|
118 |
+
action_log_probs.append(action_log_prob)
|
119 |
+
|
120 |
+
actions = torch.cat(actions, -1)
|
121 |
+
action_log_probs = torch.sum(torch.cat(action_log_probs, -1), -1, keepdim=True)
|
122 |
+
|
123 |
+
elif self.multidiscrete_action:
|
124 |
+
actions = []
|
125 |
+
action_log_probs = []
|
126 |
+
for action_out in self.action_outs:
|
127 |
+
action_logit = action_out(x)
|
128 |
+
action = action_logit.mode() if deterministic else action_logit.sample()
|
129 |
+
action_log_prob = action_logit.log_probs(action)
|
130 |
+
actions.append(action)
|
131 |
+
action_log_probs.append(action_log_prob)
|
132 |
+
|
133 |
+
actions = torch.cat(actions, -1)
|
134 |
+
action_log_probs = torch.cat(action_log_probs, -1)
|
135 |
+
|
136 |
+
elif self.continuous_action:
|
137 |
+
action_logits = self.action_out(x)
|
138 |
+
actions = action_logits.mode() if deterministic else action_logits.sample()
|
139 |
+
action_log_probs = action_logits.log_probs(actions)
|
140 |
+
|
141 |
+
else:
|
142 |
+
action_logits = self.action_out(x, available_actions)
|
143 |
+
actions = action_logits.mode() if deterministic else action_logits.sample()
|
144 |
+
action_log_probs = action_logits.log_probs(actions)
|
145 |
+
|
146 |
+
return actions, action_log_probs
|
147 |
+
|
148 |
+
def get_probs(self, x, available_actions=None):
|
149 |
+
if self.mixed_action or self.multidiscrete_action:
|
150 |
+
action_probs = []
|
151 |
+
for action_out in self.action_outs:
|
152 |
+
action_logit = action_out(x)
|
153 |
+
action_prob = action_logit.probs
|
154 |
+
action_probs.append(action_prob)
|
155 |
+
action_probs = torch.cat(action_probs, -1)
|
156 |
+
elif self.continuous_action:
|
157 |
+
action_logits = self.action_out(x)
|
158 |
+
action_probs = action_logits.probs
|
159 |
+
else:
|
160 |
+
action_logits = self.action_out(x, available_actions)
|
161 |
+
action_probs = action_logits.probs
|
162 |
+
|
163 |
+
return action_probs
|
164 |
+
|
165 |
+
def evaluate_actions(self, x, action, available_actions=None, active_masks=None, get_probs=False):
|
166 |
+
if self.mixed_action:
|
167 |
+
a, b = action.split((2, 1), -1)
|
168 |
+
b = b.long()
|
169 |
+
action = [a, b]
|
170 |
+
action_log_probs = []
|
171 |
+
dist_entropy = []
|
172 |
+
for action_out, act in zip(self.action_outs, action):
|
173 |
+
action_logit = action_out(x)
|
174 |
+
action_log_probs.append(action_logit.log_probs(act))
|
175 |
+
if active_masks is not None:
|
176 |
+
if len(action_logit.entropy().shape) == len(active_masks.shape):
|
177 |
+
dist_entropy.append((action_logit.entropy() * active_masks).sum()/active_masks.sum())
|
178 |
+
else:
|
179 |
+
dist_entropy.append((action_logit.entropy() * active_masks.squeeze(-1)).sum()/active_masks.sum())
|
180 |
+
else:
|
181 |
+
dist_entropy.append(action_logit.entropy().mean())
|
182 |
+
|
183 |
+
action_log_probs = torch.sum(torch.cat(action_log_probs, -1), -1, keepdim=True)
|
184 |
+
dist_entropy = dist_entropy[0] * 0.0025 + dist_entropy[1] * 0.01
|
185 |
+
|
186 |
+
elif self.multidiscrete_action:
|
187 |
+
action = torch.transpose(action, 0, 1)
|
188 |
+
action_log_probs = []
|
189 |
+
dist_entropy = []
|
190 |
+
for action_out, act in zip(self.action_outs, action):
|
191 |
+
action_logit = action_out(x)
|
192 |
+
action_log_probs.append(action_logit.log_probs(act))
|
193 |
+
if active_masks is not None:
|
194 |
+
dist_entropy.append((action_logit.entropy()*active_masks.squeeze(-1)).sum()/active_masks.sum())
|
195 |
+
else:
|
196 |
+
dist_entropy.append(action_logit.entropy().mean())
|
197 |
+
|
198 |
+
action_log_probs = torch.cat(action_log_probs, -1) # ! could be wrong
|
199 |
+
dist_entropy = torch.tensor(dist_entropy).mean()
|
200 |
+
|
201 |
+
elif self.continuous_action:
|
202 |
+
action_logits = self.action_out(x)
|
203 |
+
action_log_probs = action_logits.log_probs(action)
|
204 |
+
act_entropy = action_logits.entropy()
|
205 |
+
# import pdb;pdb.set_trace()
|
206 |
+
if active_masks is not None:
|
207 |
+
dist_entropy = (act_entropy*active_masks).sum()/active_masks.sum()
|
208 |
+
else:
|
209 |
+
dist_entropy = act_entropy.mean()
|
210 |
+
|
211 |
+
else:
|
212 |
+
action_logits = self.action_out(x, available_actions)
|
213 |
+
action_log_probs = action_logits.log_probs(action)
|
214 |
+
if active_masks is not None:
|
215 |
+
dist_entropy = (action_logits.entropy()*active_masks.squeeze(-1)).sum()/active_masks.sum()
|
216 |
+
else:
|
217 |
+
dist_entropy = action_logits.entropy().mean()
|
218 |
+
if not get_probs:
|
219 |
+
return action_log_probs, dist_entropy
|
220 |
+
else:
|
221 |
+
return action_log_probs, dist_entropy, action_logits
|
222 |
+
|
223 |
+
class RNNLayer(nn.Module):
|
224 |
+
def __init__(self, inputs_dim, outputs_dim, recurrent_N, use_orthogonal,rnn_type='gru'):
|
225 |
+
super(RNNLayer, self).__init__()
|
226 |
+
self._recurrent_N = recurrent_N
|
227 |
+
self._use_orthogonal = use_orthogonal
|
228 |
+
self.rnn_type = rnn_type
|
229 |
+
if rnn_type == 'gru':
|
230 |
+
self.rnn = nn.GRU(inputs_dim, outputs_dim, num_layers=self._recurrent_N)
|
231 |
+
elif rnn_type == 'lstm':
|
232 |
+
self.rnn = nn.LSTM(inputs_dim, outputs_dim, num_layers=self._recurrent_N)
|
233 |
+
else:
|
234 |
+
raise NotImplementedError(f'RNN type {rnn_type} has not been implemented.')
|
235 |
+
|
236 |
+
for name, param in self.rnn.named_parameters():
|
237 |
+
if 'bias' in name:
|
238 |
+
nn.init.constant_(param, 0)
|
239 |
+
elif 'weight' in name:
|
240 |
+
if self._use_orthogonal:
|
241 |
+
nn.init.orthogonal_(param)
|
242 |
+
else:
|
243 |
+
nn.init.xavier_uniform_(param)
|
244 |
+
self.norm = nn.LayerNorm(outputs_dim)
|
245 |
+
|
246 |
+
def rnn_forward(self, x, h):
|
247 |
+
if self.rnn_type == 'lstm':
|
248 |
+
h = torch.split(h, h.shape[-1] // 2, dim=-1)
|
249 |
+
h = (h[0].contiguous(), h[1].contiguous())
|
250 |
+
x_, h_ = self.rnn(x, h)
|
251 |
+
if self.rnn_type == 'lstm':
|
252 |
+
h_ = torch.cat(h_, -1)
|
253 |
+
return x_, h_
|
254 |
+
|
255 |
+
def forward(self, x, hxs, masks):
|
256 |
+
if x.size(0) == hxs.size(0):
|
257 |
+
x, hxs = self.rnn_forward(x.unsqueeze(0), (hxs * masks.repeat(1, self._recurrent_N).unsqueeze(-1)).transpose(0, 1).contiguous())
|
258 |
+
#x= self.gru(x.unsqueeze(0))
|
259 |
+
x = x.squeeze(0)
|
260 |
+
hxs = hxs.transpose(0, 1)
|
261 |
+
else:
|
262 |
+
# x is a (T, N, -1) tensor that has been flatten to (T * N, -1)
|
263 |
+
N = hxs.size(0)
|
264 |
+
T = int(x.size(0) / N)
|
265 |
+
|
266 |
+
# unflatten
|
267 |
+
x = x.view(T, N, x.size(1))
|
268 |
+
|
269 |
+
# Same deal with masks
|
270 |
+
masks = masks.view(T, N)
|
271 |
+
|
272 |
+
# Let's figure out which steps in the sequence have a zero for any agent
|
273 |
+
# We will always assume t=0 has a zero in it as that makes the logic cleaner
|
274 |
+
has_zeros = ((masks[1:] == 0.0)
|
275 |
+
.any(dim=-1)
|
276 |
+
.nonzero()
|
277 |
+
.squeeze()
|
278 |
+
.cpu())
|
279 |
+
|
280 |
+
# +1 to correct the masks[1:]
|
281 |
+
if has_zeros.dim() == 0:
|
282 |
+
# Deal with scalar
|
283 |
+
has_zeros = [has_zeros.item() + 1]
|
284 |
+
else:
|
285 |
+
has_zeros = (has_zeros + 1).numpy().tolist()
|
286 |
+
|
287 |
+
# add t=0 and t=T to the list
|
288 |
+
has_zeros = [0] + has_zeros + [T]
|
289 |
+
|
290 |
+
hxs = hxs.transpose(0, 1)
|
291 |
+
|
292 |
+
outputs = []
|
293 |
+
for i in range(len(has_zeros) - 1):
|
294 |
+
# We can now process steps that don't have any zeros in masks together!
|
295 |
+
# This is much faster
|
296 |
+
start_idx = has_zeros[i]
|
297 |
+
end_idx = has_zeros[i + 1]
|
298 |
+
temp = (hxs * masks[start_idx].view(1, -1, 1).repeat(self._recurrent_N, 1, 1)).contiguous()
|
299 |
+
rnn_scores, hxs = self.rnn_forward(x[start_idx:end_idx], temp)
|
300 |
+
outputs.append(rnn_scores)
|
301 |
+
|
302 |
+
# assert len(outputs) == T
|
303 |
+
# x is a (T, N, -1) tensor
|
304 |
+
x = torch.cat(outputs, dim=0)
|
305 |
+
|
306 |
+
# flatten
|
307 |
+
x = x.reshape(T * N, -1)
|
308 |
+
hxs = hxs.transpose(0, 1)
|
309 |
+
|
310 |
+
x = self.norm(x)
|
311 |
+
return x, hxs
|
312 |
+
|
313 |
+
|
314 |
+
class InputEncoder(nn.Module):
|
315 |
+
def __init__(self):
|
316 |
+
super(InputEncoder, self).__init__()
|
317 |
+
fc_layer_num = 2
|
318 |
+
fc_output_num = 64
|
319 |
+
self.active_input_num = 87
|
320 |
+
self.ball_owner_input_num = 57
|
321 |
+
self.left_input_num = 88
|
322 |
+
self.right_input_num = 88
|
323 |
+
self.match_state_input_num = 9
|
324 |
+
|
325 |
+
self.active_encoder = FcEncoder(fc_layer_num, self.active_input_num, fc_output_num)
|
326 |
+
self.ball_owner_encoder = FcEncoder(fc_layer_num, self.ball_owner_input_num, fc_output_num)
|
327 |
+
self.left_encoder = FcEncoder(fc_layer_num, self.left_input_num, fc_output_num)
|
328 |
+
self.right_encoder = FcEncoder(fc_layer_num, self.right_input_num, fc_output_num)
|
329 |
+
self.match_state_encoder = FcEncoder(fc_layer_num, self.match_state_input_num, self.match_state_input_num)
|
330 |
+
|
331 |
+
def forward(self, x):
|
332 |
+
active_vec = x[:, :self.active_input_num]
|
333 |
+
ball_owner_vec = x[:, self.active_input_num : self.active_input_num + self.ball_owner_input_num]
|
334 |
+
left_vec = x[:, self.active_input_num + self.ball_owner_input_num : self.active_input_num + self.ball_owner_input_num + self.left_input_num]
|
335 |
+
right_vec = x[:, self.active_input_num + self.ball_owner_input_num + self.left_input_num : \
|
336 |
+
self.active_input_num + self.ball_owner_input_num + self.left_input_num + self.right_input_num]
|
337 |
+
match_state_vec = x[:, self.active_input_num + self.ball_owner_input_num + self.left_input_num + self.right_input_num:]
|
338 |
+
|
339 |
+
active_output = self.active_encoder(active_vec)
|
340 |
+
ball_owner_output = self.ball_owner_encoder(ball_owner_vec)
|
341 |
+
left_output = self.left_encoder(left_vec)
|
342 |
+
right_output = self.right_encoder(right_vec)
|
343 |
+
match_state_output = self.match_state_encoder(match_state_vec)
|
344 |
+
|
345 |
+
return torch.cat([
|
346 |
+
active_output,
|
347 |
+
ball_owner_output,
|
348 |
+
left_output,
|
349 |
+
right_output,
|
350 |
+
match_state_output
|
351 |
+
], 1)
|
352 |
+
|
353 |
+
def get_fc(input_size, output_size):
|
354 |
+
return nn.Sequential(nn.Linear(input_size, output_size), nn.ReLU(), nn.LayerNorm(output_size))
|
355 |
+
|
356 |
+
class ObsEncoder(nn.Module):
|
357 |
+
def __init__(self, input_embedding_size, hidden_size, _recurrent_N, _use_orthogonal, rnn_type):
|
358 |
+
super(ObsEncoder, self).__init__()
|
359 |
+
self.input_encoder = InputEncoder() # input先过一遍input encoder
|
360 |
+
self.input_embedding = get_fc(input_embedding_size, hidden_size) # 将encoder输出进行embedding
|
361 |
+
self.rnn = RNNLayer(hidden_size, hidden_size, _recurrent_N, _use_orthogonal, rnn_type=rnn_type) # embedding输出过一遍rnn
|
362 |
+
self.after_rnn_mlp = get_fc(hidden_size, hidden_size) # 过了rnn后再过该mlp
|
363 |
+
|
364 |
+
def forward(self, obs, rnn_states, masks):
|
365 |
+
actor_features = self.input_encoder(obs)
|
366 |
+
actor_features = self.input_embedding(actor_features)
|
367 |
+
output, rnn_states = self.rnn(actor_features, rnn_states, masks)
|
368 |
+
return self.after_rnn_mlp(output), rnn_states
|
369 |
+
|
370 |
+
|
371 |
+
class PolicyNetwork(nn.Module):
|
372 |
+
def __init__(self, device=torch.device("cpu")):
|
373 |
+
super(PolicyNetwork, self).__init__()
|
374 |
+
self.tpdv = dict(dtype=torch.float32, device=device)
|
375 |
+
self.device = device
|
376 |
+
self.hidden_size = 256
|
377 |
+
self._use_policy_active_masks = True
|
378 |
+
recurrent_N = 1
|
379 |
+
use_orthogonal = True
|
380 |
+
rnn_type = 'lstm'
|
381 |
+
gain = 0.01
|
382 |
+
action_space = gym.spaces.Discrete(20)
|
383 |
+
self.action_dim = 19
|
384 |
+
input_embedding_size = 64 * 4 + 9
|
385 |
+
self.active_id_size = 1
|
386 |
+
self.id_max = 11
|
387 |
+
|
388 |
+
self.obs_encoder = ObsEncoder(input_embedding_size, self.hidden_size, recurrent_N, use_orthogonal, rnn_type)
|
389 |
+
|
390 |
+
self.predict_id = get_fc(self.hidden_size + self.action_dim, self.id_max) # 其他信息(指除了active_id外的信息)过了rnn和一层mlp后,经过该层来预测id
|
391 |
+
self.id_embedding = get_fc(self.id_max, self.id_max) # active id作为输入,输出和其他信息的feature concat
|
392 |
+
|
393 |
+
self.before_act_wrapper = FcEncoder(2, self.hidden_size + self.id_max, self.hidden_size)
|
394 |
+
self.act = ACTLayer(action_space, self.hidden_size, use_orthogonal, gain)
|
395 |
+
|
396 |
+
self.to(device)
|
397 |
+
|
398 |
+
|
399 |
+
def forward(self, obs, rnn_states, masks=np.concatenate(np.ones((1, 1, 1), dtype=np.float32)), available_actions=None, deterministic=False):
|
400 |
+
obs = check(obs).to(**self.tpdv)
|
401 |
+
if available_actions is not None:
|
402 |
+
available_actions = check(available_actions).to(**self.tpdv)
|
403 |
+
masks = check(masks).to(**self.tpdv)
|
404 |
+
rnn_states = check(rnn_states).to(**self.tpdv)
|
405 |
+
|
406 |
+
active_id = obs[:,:self.active_id_size].squeeze(1).long()
|
407 |
+
id_onehot = torch.eye(self.id_max)[active_id].to(self.device)
|
408 |
+
obs = obs[:,self.active_id_size:]
|
409 |
+
|
410 |
+
obs_output, rnn_states = self.obs_encoder(obs, rnn_states, masks)
|
411 |
+
id_output = self.id_embedding(id_onehot)
|
412 |
+
output = torch.cat([id_output, obs_output], 1)
|
413 |
+
|
414 |
+
output = self.before_act_wrapper(output)
|
415 |
+
|
416 |
+
actions, action_log_probs = self.act(output, available_actions, deterministic)
|
417 |
+
return actions, rnn_states
|
418 |
+
|
419 |
+
def eval_actions(self, obs, rnn_states, action, masks, available_actions=None, active_masks=None):
|
420 |
+
obs = check(obs).to(**self.tpdv)
|
421 |
+
if available_actions is not None:
|
422 |
+
available_actions = check(available_actions).to(**self.tpdv)
|
423 |
+
if active_masks is not None:
|
424 |
+
active_masks = check(active_masks).to(**self.tpdv)
|
425 |
+
masks = check(masks).to(**self.tpdv)
|
426 |
+
rnn_states = check(rnn_states).to(**self.tpdv)
|
427 |
+
action = check(action).to(**self.tpdv)
|
428 |
+
|
429 |
+
id_groundtruth = obs[:,:self.active_id_size].squeeze(1).long()
|
430 |
+
id_onehot = torch.eye(self.id_max)[id_groundtruth].to(self.device)
|
431 |
+
obs = obs[:,self.active_id_size:]
|
432 |
+
|
433 |
+
obs_output, rnn_states = self.obs_encoder(obs, rnn_states, masks)
|
434 |
+
id_output = self.id_embedding(id_onehot)
|
435 |
+
|
436 |
+
action_onehot = torch.eye(self.action_dim)[action.squeeze(1).long()].to(self.device)
|
437 |
+
|
438 |
+
id_prediction = self.predict_id(torch.cat([obs_output, action_onehot], 1))
|
439 |
+
output = torch.cat([id_output, obs_output], 1)
|
440 |
+
|
441 |
+
output = self.before_act_wrapper(output)
|
442 |
+
action_log_probs, dist_entropy = self.act.evaluate_actions(output, action, available_actions,
|
443 |
+
active_masks=active_masks if self._use_policy_active_masks else None)
|
444 |
+
values = None
|
445 |
+
return action_log_probs, dist_entropy, values, id_prediction, id_groundtruth
|
446 |
+
|
openrl_utils.py
ADDED
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Copyright 2023 The OpenRL Authors.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
|
19 |
+
# Area.
|
20 |
+
THIRD_X = 0.3
|
21 |
+
BOX_X = 0.7
|
22 |
+
MAX_X = 1.0
|
23 |
+
BOX_Y = 0.24
|
24 |
+
MAX_Y = 0.42
|
25 |
+
|
26 |
+
# Actions.
|
27 |
+
IDLE = 0
|
28 |
+
LEFT = 1
|
29 |
+
TOP_LEFT = 2
|
30 |
+
TOP = 3
|
31 |
+
TOP_RIGHT = 4
|
32 |
+
RIGHT = 5
|
33 |
+
BOTTOM_RIGHT = 6
|
34 |
+
BOTTOM = 7
|
35 |
+
BOTTOM_LEFT = 8
|
36 |
+
LONG_PASS = 9
|
37 |
+
HIGH_PASS = 10
|
38 |
+
SHORT_PASS = 11
|
39 |
+
SHOT = 12
|
40 |
+
SPRINT = 13
|
41 |
+
RELEASE_DIRECTION = 14
|
42 |
+
RELEASE_SPRINT = 15
|
43 |
+
SLIDING = 16
|
44 |
+
DRIBBLE = 17
|
45 |
+
RELEASE_DRIBBLE = 18
|
46 |
+
STICKY_LEFT = 0
|
47 |
+
STICKY_TOP_LEFT = 1
|
48 |
+
STICKY_TOP = 2
|
49 |
+
STICKY_TOP_RIGHT = 3
|
50 |
+
STICKY_RIGHT = 4
|
51 |
+
STICKY_BOTTOM_RIGHT = 5
|
52 |
+
STICKY_BOTTOM = 6
|
53 |
+
STICKY_BOTTOM_LEFT = 7
|
54 |
+
|
55 |
+
RIGHT_ACTIONS = [TOP_RIGHT, RIGHT, BOTTOM_RIGHT, TOP, BOTTOM]
|
56 |
+
LEFT_ACTIONS = [TOP_LEFT, LEFT, BOTTOM_LEFT, TOP, BOTTOM]
|
57 |
+
BOTTOM_ACTIONS = [BOTTOM_LEFT, BOTTOM, BOTTOM_RIGHT, LEFT, RIGHT]
|
58 |
+
TOP_ACTIONS = [TOP_LEFT, TOP, TOP_RIGHT, LEFT, RIGHT]
|
59 |
+
ALL_DIRECTION_ACTIONS = [LEFT, TOP_LEFT, TOP, TOP_RIGHT, RIGHT, BOTTOM_RIGHT, BOTTOM, BOTTOM_LEFT]
|
60 |
+
ALL_DIRECTION_VECS = [(-1, 0), (-1, -1), (0, -1), (1, -1), (1, 0), (1, 1), (0, 1), (-1, 1)]
|
61 |
+
|
62 |
+
def get_direction_action(available_action, sticky_actions, forbidden_action, target_action, active_direction, need_sprint):
|
63 |
+
available_action = np.zeros(19)
|
64 |
+
available_action[forbidden_action] = 0
|
65 |
+
available_action[target_action] = 1
|
66 |
+
|
67 |
+
if need_sprint:
|
68 |
+
available_action[RELEASE_SPRINT] = 0
|
69 |
+
if sticky_actions[8] == 0:
|
70 |
+
available_action = np.zeros(19)
|
71 |
+
available_action[SPRINT] = 1
|
72 |
+
else:
|
73 |
+
available_action[SPRINT] = 0
|
74 |
+
if sticky_actions[8] == 1:
|
75 |
+
available_action = np.zeros(19)
|
76 |
+
available_action[RELEASE_SPRINT] = 1
|
77 |
+
return available_action
|
78 |
+
|
79 |
+
def openrl_obs_deal(obs):
|
80 |
+
|
81 |
+
direction_x_bound = 0.03
|
82 |
+
direction_y_bound = 0.02
|
83 |
+
ball_direction_x_bound = 0.15
|
84 |
+
ball_direction_y_bound = 0.07
|
85 |
+
ball_direction_z_bound = 4
|
86 |
+
ball_rotation_x_bound = 0.0005
|
87 |
+
ball_rotation_y_bound = 0.0004
|
88 |
+
ball_rotation_z_bound = 0.015
|
89 |
+
active_id = [obs["active"]]
|
90 |
+
assert active_id[0] < 11 and active_id[0] >= 0, "active id is wrong, active id = {}".format(active_id[0])
|
91 |
+
# left team 88
|
92 |
+
left_position = np.concatenate(obs["left_team"])
|
93 |
+
left_direction = np.concatenate(obs["left_team_direction"])
|
94 |
+
left_tired_factor = obs["left_team_tired_factor"]
|
95 |
+
left_yellow_card = obs["left_team_yellow_card"]
|
96 |
+
left_red_card = ~obs["left_team_active"]
|
97 |
+
left_offside = np.zeros(11)
|
98 |
+
if obs["ball_owned_team"] == 0:
|
99 |
+
left_offside_line = max(0, obs["ball"][0], np.sort(obs["right_team"][:, 0])[-2])
|
100 |
+
left_offside = obs["left_team"][:, 0] > left_offside_line
|
101 |
+
left_offside[obs["ball_owned_player"]] = False
|
102 |
+
|
103 |
+
new_left_direction = left_direction.copy()
|
104 |
+
for counting in range(len(new_left_direction)):
|
105 |
+
new_left_direction[counting] = new_left_direction[counting] / direction_x_bound if counting % 2 == 0 else new_left_direction[counting] / direction_y_bound
|
106 |
+
|
107 |
+
left_team = np.concatenate([
|
108 |
+
left_position,
|
109 |
+
new_left_direction,
|
110 |
+
left_tired_factor,
|
111 |
+
left_yellow_card,
|
112 |
+
left_red_card,
|
113 |
+
left_offside,
|
114 |
+
]).astype(np.float64)
|
115 |
+
|
116 |
+
# right team 88
|
117 |
+
right_position = np.concatenate(obs["right_team"])
|
118 |
+
right_direction = np.concatenate(obs["right_team_direction"])
|
119 |
+
right_tired_factor = obs["right_team_tired_factor"]
|
120 |
+
right_yellow_card = obs["right_team_yellow_card"]
|
121 |
+
right_red_card = ~obs["right_team_active"]
|
122 |
+
right_offside = np.zeros(11)
|
123 |
+
if obs["ball_owned_team"] == 1:
|
124 |
+
right_offside_line = min(0, obs["ball"][0], np.sort(obs["left_team"][:, 0])[1])
|
125 |
+
right_offside = obs["right_team"][:, 0] < right_offside_line
|
126 |
+
right_offside[obs["ball_owned_player"]] = False
|
127 |
+
|
128 |
+
new_right_direction = right_direction.copy()
|
129 |
+
for counting in range(len(new_right_direction)):
|
130 |
+
new_right_direction[counting] = new_right_direction[counting] / direction_x_bound if counting % 2 == 0 else new_right_direction[counting] / direction_y_bound
|
131 |
+
|
132 |
+
right_team = np.concatenate([
|
133 |
+
right_position,
|
134 |
+
new_right_direction,
|
135 |
+
right_tired_factor,
|
136 |
+
right_yellow_card,
|
137 |
+
right_red_card,
|
138 |
+
right_offside,
|
139 |
+
]).astype(np.float64)
|
140 |
+
|
141 |
+
# active 18
|
142 |
+
sticky_actions = obs["sticky_actions"][:10]
|
143 |
+
active_position = obs["left_team"][obs["active"]]
|
144 |
+
active_direction = obs["left_team_direction"][obs["active"]]
|
145 |
+
active_tired_factor = left_tired_factor[obs["active"]]
|
146 |
+
active_yellow_card = left_yellow_card[obs["active"]]
|
147 |
+
active_red_card = left_red_card[obs["active"]]
|
148 |
+
active_offside = left_offside[obs["active"]]
|
149 |
+
|
150 |
+
new_active_direction = active_direction.copy()
|
151 |
+
new_active_direction[0] /= direction_x_bound
|
152 |
+
new_active_direction[1] /= direction_y_bound
|
153 |
+
|
154 |
+
active_player = np.concatenate([
|
155 |
+
sticky_actions,
|
156 |
+
active_position,
|
157 |
+
new_active_direction,
|
158 |
+
[active_tired_factor],
|
159 |
+
[active_yellow_card],
|
160 |
+
[active_red_card],
|
161 |
+
[active_offside],
|
162 |
+
]).astype(np.float64)
|
163 |
+
|
164 |
+
# relative 69
|
165 |
+
relative_ball_position = obs["ball"][:2] - active_position
|
166 |
+
distance2ball = np.linalg.norm(relative_ball_position)
|
167 |
+
relative_left_position = obs["left_team"] - active_position
|
168 |
+
distance2left = np.linalg.norm(relative_left_position, axis=1)
|
169 |
+
relative_left_position = np.concatenate(relative_left_position)
|
170 |
+
relative_right_position = obs["right_team"] - active_position
|
171 |
+
distance2right = np.linalg.norm(relative_right_position, axis=1)
|
172 |
+
relative_right_position = np.concatenate(relative_right_position)
|
173 |
+
relative_info = np.concatenate([
|
174 |
+
relative_ball_position,
|
175 |
+
[distance2ball],
|
176 |
+
relative_left_position,
|
177 |
+
distance2left,
|
178 |
+
relative_right_position,
|
179 |
+
distance2right,
|
180 |
+
]).astype(np.float64)
|
181 |
+
|
182 |
+
active_info = np.concatenate([active_player, relative_info]) # 87
|
183 |
+
|
184 |
+
# ball info 12
|
185 |
+
ball_owned_team = np.zeros(3)
|
186 |
+
ball_owned_team[obs["ball_owned_team"] + 1] = 1.0
|
187 |
+
new_ball_direction = obs["ball_direction"].copy()
|
188 |
+
new_ball_rotation = obs['ball_rotation'].copy()
|
189 |
+
for counting in range(len(new_ball_direction)):
|
190 |
+
if counting % 3 == 0:
|
191 |
+
new_ball_direction[counting] /= ball_direction_x_bound
|
192 |
+
new_ball_rotation[counting] /= ball_rotation_x_bound
|
193 |
+
if counting % 3 == 1:
|
194 |
+
new_ball_direction[counting] /= ball_direction_y_bound
|
195 |
+
new_ball_rotation[counting] /= ball_rotation_y_bound
|
196 |
+
if counting % 3 == 2:
|
197 |
+
new_ball_direction[counting] /= ball_direction_z_bound
|
198 |
+
new_ball_rotation[counting] /= ball_rotation_z_bound
|
199 |
+
ball_info = np.concatenate([
|
200 |
+
obs["ball"],
|
201 |
+
new_ball_direction,
|
202 |
+
ball_owned_team,
|
203 |
+
new_ball_rotation
|
204 |
+
]).astype(np.float64)
|
205 |
+
# ball owned player 23
|
206 |
+
ball_owned_player = np.zeros(23)
|
207 |
+
if obs["ball_owned_team"] == 1: # 对手
|
208 |
+
ball_owned_player[11 + obs['ball_owned_player']] = 1.0
|
209 |
+
ball_owned_player_pos = obs['right_team'][obs['ball_owned_player']]
|
210 |
+
ball_owned_player_direction = obs["right_team_direction"][obs['ball_owned_player']]
|
211 |
+
ball_owner_tired_factor = right_tired_factor[obs['ball_owned_player']]
|
212 |
+
ball_owner_yellow_card = right_yellow_card[obs['ball_owned_player']]
|
213 |
+
ball_owner_red_card = right_red_card[obs['ball_owned_player']]
|
214 |
+
ball_owner_offside = right_offside[obs['ball_owned_player']]
|
215 |
+
elif obs["ball_owned_team"] == 0:
|
216 |
+
ball_owned_player[obs['ball_owned_player']] = 1.0
|
217 |
+
ball_owned_player_pos = obs['left_team'][obs['ball_owned_player']]
|
218 |
+
ball_owned_player_direction = obs["left_team_direction"][obs['ball_owned_player']]
|
219 |
+
ball_owner_tired_factor = left_tired_factor[obs['ball_owned_player']]
|
220 |
+
ball_owner_yellow_card = left_yellow_card[obs['ball_owned_player']]
|
221 |
+
ball_owner_red_card = left_red_card[obs['ball_owned_player']]
|
222 |
+
ball_owner_offside = left_offside[obs['ball_owned_player']]
|
223 |
+
else:
|
224 |
+
ball_owned_player[-1] = 1.0
|
225 |
+
ball_owned_player_pos = np.zeros(2)
|
226 |
+
ball_owned_player_direction = np.zeros(2)
|
227 |
+
|
228 |
+
relative_ball_owner_position = np.zeros(2)
|
229 |
+
distance2ballowner = np.zeros(1)
|
230 |
+
ball_owner_info = np.zeros(4)
|
231 |
+
if obs["ball_owned_team"] != -1:
|
232 |
+
relative_ball_owner_position = ball_owned_player_pos - active_position
|
233 |
+
distance2ballowner = [np.linalg.norm(relative_ball_owner_position)]
|
234 |
+
ball_owner_info = np.concatenate([
|
235 |
+
[ball_owner_tired_factor],
|
236 |
+
[ball_owner_yellow_card],
|
237 |
+
[ball_owner_red_card],
|
238 |
+
[ball_owner_offside]
|
239 |
+
])
|
240 |
+
|
241 |
+
new_ball_owned_player_direction = ball_owned_player_direction.copy()
|
242 |
+
new_ball_owned_player_direction[0] /= direction_x_bound
|
243 |
+
new_ball_owned_player_direction[1] /= direction_y_bound
|
244 |
+
|
245 |
+
ball_own_active_info = np.concatenate([
|
246 |
+
ball_info, # 12
|
247 |
+
ball_owned_player, # 23
|
248 |
+
active_position, # 2
|
249 |
+
new_active_direction, # 2
|
250 |
+
[active_tired_factor], # 1
|
251 |
+
[active_yellow_card], # 1
|
252 |
+
[active_red_card], # 1
|
253 |
+
[active_offside], # 1
|
254 |
+
relative_ball_position, # 2
|
255 |
+
[distance2ball], # 1
|
256 |
+
ball_owned_player_pos, # 2
|
257 |
+
new_ball_owned_player_direction, # 2
|
258 |
+
relative_ball_owner_position, # 2
|
259 |
+
distance2ballowner, # 1
|
260 |
+
ball_owner_info # 4
|
261 |
+
])
|
262 |
+
|
263 |
+
# match state
|
264 |
+
game_mode = np.zeros(7)
|
265 |
+
game_mode[obs["game_mode"]] = 1.0
|
266 |
+
goal_diff_ratio = (obs["score"][0] - obs["score"][1]) / 5
|
267 |
+
steps_left_ratio = obs["steps_left"] / 3001
|
268 |
+
match_state = np.concatenate([
|
269 |
+
game_mode,
|
270 |
+
[goal_diff_ratio],
|
271 |
+
[steps_left_ratio],
|
272 |
+
]).astype(np.float64)
|
273 |
+
|
274 |
+
# available action
|
275 |
+
available_action = np.ones(19)
|
276 |
+
available_action[IDLE] = 0
|
277 |
+
available_action[RELEASE_DIRECTION] = 0
|
278 |
+
should_left = False
|
279 |
+
|
280 |
+
|
281 |
+
if obs["game_mode"] == 0:
|
282 |
+
active_x = active_position[0]
|
283 |
+
counting_right_enemy_num = 0
|
284 |
+
counting_right_teammate_num = 0
|
285 |
+
counting_left_teammate_num = 0
|
286 |
+
for enemy_pos in obs["right_team"][1:]:
|
287 |
+
if active_x < enemy_pos[0]:
|
288 |
+
counting_right_enemy_num += 1
|
289 |
+
for teammate_pos in obs["left_team"][1:]:
|
290 |
+
if active_x < teammate_pos[0]:
|
291 |
+
counting_right_teammate_num += 1
|
292 |
+
if active_x > teammate_pos[0]:
|
293 |
+
counting_left_teammate_num += 1
|
294 |
+
|
295 |
+
if active_x > obs['ball'][0] + 0.05:
|
296 |
+
|
297 |
+
if counting_left_teammate_num < 2:
|
298 |
+
|
299 |
+
if obs['ball_owned_team'] != 0:
|
300 |
+
should_left = True
|
301 |
+
if should_left:
|
302 |
+
available_action = get_direction_action(available_action, sticky_actions, RIGHT_ACTIONS, [LEFT, BOTTOM_LEFT, TOP_LEFT], active_direction, True)
|
303 |
+
|
304 |
+
|
305 |
+
if (abs(relative_ball_position[0]) > 0.75 or abs(relative_ball_position[1]) > 0.5):
|
306 |
+
all_directions_vecs = [np.array(v) / np.linalg.norm(np.array(v)) for v in ALL_DIRECTION_VECS]
|
307 |
+
best_direction = np.argmax([np.dot(relative_ball_position, v) for v in all_directions_vecs])
|
308 |
+
target_direction = ALL_DIRECTION_ACTIONS[best_direction]
|
309 |
+
forbidden_actions = ALL_DIRECTION_ACTIONS.copy()
|
310 |
+
forbidden_actions.remove(target_direction)
|
311 |
+
available_action = get_direction_action(available_action, sticky_actions, forbidden_actions, [target_direction], active_direction, True)
|
312 |
+
|
313 |
+
|
314 |
+
if_i_hold_ball = (obs["ball_owned_team"] == 0 and obs["ball_owned_player"] == obs['active'])
|
315 |
+
ball_pos_offset = 0.05
|
316 |
+
no_ball_pos_offset = 0.03
|
317 |
+
|
318 |
+
active_x, active_y = active_position[0], active_position[1]
|
319 |
+
if_outside = False
|
320 |
+
if active_x <= (-1 + no_ball_pos_offset) or (if_i_hold_ball and active_x <= (-1 + ball_pos_offset)):
|
321 |
+
if_outside = True
|
322 |
+
action_index = LEFT_ACTIONS
|
323 |
+
target_direction = RIGHT
|
324 |
+
elif active_x >= (1 - no_ball_pos_offset) or (if_i_hold_ball and active_x >= (1 - ball_pos_offset)):
|
325 |
+
if_outside = True
|
326 |
+
action_index = RIGHT_ACTIONS
|
327 |
+
target_direction = LEFT
|
328 |
+
elif active_y >= (0.42 - no_ball_pos_offset) or (if_i_hold_ball and active_y >= (0.42 - ball_pos_offset)):
|
329 |
+
if_outside = True
|
330 |
+
action_index = BOTTOM_ACTIONS
|
331 |
+
target_direction = TOP
|
332 |
+
elif active_y <= (-0.42 + no_ball_pos_offset) or (if_i_hold_ball and active_x <= (-0.42 + ball_pos_offset)):
|
333 |
+
if_outside = True
|
334 |
+
action_index = TOP_ACTIONS
|
335 |
+
target_direction = BOTTOM
|
336 |
+
if obs["game_mode"] in [1, 2, 3, 4, 5]:
|
337 |
+
left2ball = np.linalg.norm(obs["left_team"] - obs["ball"][:2], axis=1)
|
338 |
+
right2ball = np.linalg.norm(obs["right_team"] - obs["ball"][:2], axis=1)
|
339 |
+
if np.min(left2ball) < np.min(right2ball) and obs["active"] == np.argmin(left2ball):
|
340 |
+
if_outside = False
|
341 |
+
elif obs["game_mode"] in [6]:
|
342 |
+
if obs["ball"][0] > 0 and active_position[0] > BOX_X:
|
343 |
+
if_outside = False
|
344 |
+
if if_outside:
|
345 |
+
available_action = get_direction_action(available_action, sticky_actions, action_index, [target_direction], active_direction, False)
|
346 |
+
|
347 |
+
if np.sum(sticky_actions[:8]) == 0:
|
348 |
+
available_action[RELEASE_DIRECTION] = 0
|
349 |
+
if sticky_actions[8] == 0:
|
350 |
+
available_action[RELEASE_SPRINT] = 0
|
351 |
+
else:
|
352 |
+
available_action[SPRINT] = 0
|
353 |
+
if sticky_actions[9] == 0:
|
354 |
+
available_action[RELEASE_DRIBBLE] = 0
|
355 |
+
else:
|
356 |
+
available_action[DRIBBLE] = 0
|
357 |
+
if active_position[0] < 0.4 or abs(active_position[1]) > 0.3:
|
358 |
+
available_action[SHOT] = 0
|
359 |
+
|
360 |
+
if obs["game_mode"] == 0:
|
361 |
+
if obs["ball_owned_team"] == -1:
|
362 |
+
available_action[DRIBBLE] = 0
|
363 |
+
if distance2ball >= 0.05:
|
364 |
+
available_action[SLIDING] = 0
|
365 |
+
available_action[[LONG_PASS, HIGH_PASS, SHORT_PASS, SHOT]] = 0
|
366 |
+
elif obs["ball_owned_team"] == 0:
|
367 |
+
available_action[SLIDING] = 0
|
368 |
+
if distance2ball >= 0.05:
|
369 |
+
available_action[[LONG_PASS, HIGH_PASS, SHORT_PASS, SHOT, DRIBBLE]] = 0
|
370 |
+
elif obs["ball_owned_team"] == 1:
|
371 |
+
available_action[DRIBBLE] = 0
|
372 |
+
if distance2ball >= 0.05:
|
373 |
+
available_action[[LONG_PASS, HIGH_PASS, SHORT_PASS, SHOT, SLIDING]] = 0
|
374 |
+
elif obs["game_mode"] in [1, 2, 3, 4, 5]:
|
375 |
+
left2ball = np.linalg.norm(obs["left_team"] - obs["ball"][:2], axis=1)
|
376 |
+
right2ball = np.linalg.norm(obs["right_team"] - obs["ball"][:2], axis=1)
|
377 |
+
if np.min(left2ball) < np.min(right2ball) and obs["active"] == np.argmin(left2ball):
|
378 |
+
available_action[[SPRINT, RELEASE_SPRINT, SLIDING, DRIBBLE, RELEASE_DRIBBLE]] = 0
|
379 |
+
else:
|
380 |
+
available_action[[LONG_PASS, HIGH_PASS, SHORT_PASS, SHOT]] = 0
|
381 |
+
available_action[[SLIDING, DRIBBLE, RELEASE_DRIBBLE]] = 0
|
382 |
+
elif obs["game_mode"] == 6:
|
383 |
+
if obs["ball"][0] > 0 and active_position[0] > BOX_X:
|
384 |
+
available_action[[LONG_PASS, HIGH_PASS, SHORT_PASS]] = 0
|
385 |
+
available_action[[SPRINT, RELEASE_SPRINT, SLIDING, DRIBBLE, RELEASE_DRIBBLE]] = 0
|
386 |
+
else:
|
387 |
+
available_action[[LONG_PASS, HIGH_PASS, SHORT_PASS, SHOT]] = 0
|
388 |
+
available_action[[SLIDING, DRIBBLE, RELEASE_DRIBBLE]] = 0
|
389 |
+
|
390 |
+
|
391 |
+
obs = np.concatenate([
|
392 |
+
active_id, # 1
|
393 |
+
active_info, # 87
|
394 |
+
ball_own_active_info, # 57
|
395 |
+
left_team, # 88
|
396 |
+
right_team, # 88
|
397 |
+
match_state, # 9
|
398 |
+
])
|
399 |
+
|
400 |
+
share_obs = np.concatenate([
|
401 |
+
ball_info, # 12
|
402 |
+
ball_owned_player, # 23
|
403 |
+
left_team, # 88
|
404 |
+
right_team, # 88
|
405 |
+
match_state, # 9
|
406 |
+
])
|
407 |
+
|
408 |
+
assert available_action.sum() > 0
|
409 |
+
return dict(
|
410 |
+
obs=obs,
|
411 |
+
share_obs=share_obs,
|
412 |
+
available_action=available_action,
|
413 |
+
)
|
414 |
+
|
415 |
+
|
416 |
+
def _t2n(x):
|
417 |
+
return x.detach().cpu().numpy()
|
418 |
+
|
419 |
+
|
420 |
+
|
421 |
+
|
submission.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Copyright 2023 The OpenRL Authors.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
|
18 |
+
""""""
|
19 |
+
import os
|
20 |
+
import sys
|
21 |
+
from pathlib import Path
|
22 |
+
import numpy as np
|
23 |
+
import torch
|
24 |
+
|
25 |
+
base_dir = Path(__file__).resolve().parent
|
26 |
+
sys.path.append(str(base_dir))
|
27 |
+
|
28 |
+
from openrl_policy import PolicyNetwork
|
29 |
+
from openrl_utils import openrl_obs_deal, _t2n
|
30 |
+
from goal_keeper import agent_get_action
|
31 |
+
|
32 |
+
class OpenRLAgent():
|
33 |
+
def __init__(self):
|
34 |
+
rnn_shape = [1,1,1,512]
|
35 |
+
self.rnn_hidden_state = [np.zeros(rnn_shape, dtype=np.float32) for _ in range (11)]
|
36 |
+
self.model = PolicyNetwork()
|
37 |
+
self.model.load_state_dict(torch.load( os.path.dirname(os.path.abspath(__file__)) + '/actor.pt', map_location=torch.device("cpu")))
|
38 |
+
self.model.eval()
|
39 |
+
|
40 |
+
def get_action(self,raw_obs,idx):
|
41 |
+
if idx == 0:
|
42 |
+
re_action = [[0]*19]
|
43 |
+
re_action_index = agent_get_action(raw_obs)[0]
|
44 |
+
re_action[0][re_action_index] = 1
|
45 |
+
return re_action
|
46 |
+
|
47 |
+
openrl_obs = openrl_obs_deal(raw_obs)
|
48 |
+
|
49 |
+
obs = openrl_obs['obs']
|
50 |
+
obs = np.concatenate(obs.reshape(1, 1, 330))
|
51 |
+
rnn_hidden_state = np.concatenate(self.rnn_hidden_state[idx])
|
52 |
+
avail_actions = np.zeros(20)
|
53 |
+
avail_actions[:19] = openrl_obs['available_action']
|
54 |
+
avail_actions = np.concatenate(avail_actions.reshape([1, 1, 20]))
|
55 |
+
with torch.no_grad():
|
56 |
+
actions, rnn_hidden_state = self.model(obs, rnn_hidden_state, available_actions=avail_actions, deterministic=True)
|
57 |
+
if actions[0][0] == 17 and raw_obs["sticky_actions"][8] == 1:
|
58 |
+
actions[0][0] = 15
|
59 |
+
self.rnn_hidden_state[idx] = np.array(np.split(_t2n(rnn_hidden_state), 1))
|
60 |
+
|
61 |
+
re_action = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
|
62 |
+
re_action[0][actions[0]] = 1
|
63 |
+
|
64 |
+
return re_action
|
65 |
+
|
66 |
+
agent = OpenRLAgent()
|
67 |
+
|
68 |
+
def my_controller(obs_list, action_space_list, is_act_continuous=False):
|
69 |
+
idx = obs_list['controlled_player_index'] % 11
|
70 |
+
del obs_list['controlled_player_index']
|
71 |
+
action = agent.get_action(obs_list,idx)
|
72 |
+
return action
|
73 |
+
|
74 |
+
def jidi_controller(obs_list=None):
|
75 |
+
if obs_list is None:
|
76 |
+
return
|
77 |
+
#重命名,防止加载错误
|
78 |
+
re = my_controller(obs_list,None)
|
79 |
+
assert isinstance(re,list)
|
80 |
+
assert isinstance(re[0],list)
|
81 |
+
return re
|