ShiyuHuang commited on
Commit
2322e9b
1 Parent(s): 7da2e8e

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. goal_keeper.py +1001 -0
  2. openrl_policy.py +446 -0
  3. openrl_utils.py +421 -0
  4. submission.py +81 -0
goal_keeper.py ADDED
@@ -0,0 +1,1001 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright 2023 The OpenRL Authors.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # https://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ # original code from https://github.com/Sarvar-Anvarov/Google-Research-Football/blob/main/gfootball.py
18
+ # modified by TARTRL team
19
+
20
+ import math
21
+ import random
22
+ import numpy as np
23
+
24
+ from functools import wraps
25
+ from enum import Enum
26
+ from typing import *
27
+
28
+
29
+
30
class Action(Enum):
    """Discrete action ids of the Google Research Football default action set.

    Values 1-8 are the eight movement directions; the rest are ball
    actions and sticky-action toggles/releases.
    """
    Idle = 0
    # movement directions (sticky)
    Left = 1
    TopLeft = 2
    Top = 3
    TopRight = 4
    Right = 5
    BottomRight = 6
    Bottom = 7
    BottomLeft = 8
    # ball actions
    LongPass = 9
    HighPass = 10
    ShortPass = 11
    Shot = 12
    # sticky toggles and their releases
    Sprint = 13
    ReleaseDirection = 14
    ReleaseSprint = 15
    Slide = 16
    Dribble = 17
    ReleaseDribble = 18
50
+
51
+
52
# The eight movement actions and their grid direction vectors, kept in the
# SAME order so an argmax over the vectors indexes the matching action
# (relied upon by get_best_direction).
ALL_DIRECTION_ACTIONS = [Action.Left, Action.TopLeft, Action.Top, Action.TopRight, Action.Right, Action.BottomRight, Action.Bottom, Action.BottomLeft]
# (x, y) offsets: x grows toward the opponent goal, y grows toward the bottom of the pitch.
ALL_DIRECTION_VECS = [(-1, 0), (-1, -1), (0, -1), (1, -1), (1, 0), (1, 1), (0, 1), (-1, 1)]
54
+
55
# Maps each bit of the raw obs['sticky_actions'] vector to its Action enum;
# order must match the environment's sticky-action layout.
sticky_index_to_action = [
    Action.Left,
    Action.TopLeft,
    Action.Top,
    Action.TopRight,
    Action.Right,
    Action.BottomRight,
    Action.Bottom,
    Action.BottomLeft,
    Action.Sprint,
    Action.Dribble
]
67
+
68
+ GOAL_BIAS = 0.01
69
+
70
class PlayerRole(Enum):
    """Player role ids reported in obs['left_team_roles']/['right_team_roles']."""
    GoalKeeper = 0
    CenterBack = 1
    LeftBack = 2
    RightBack = 3
    DefenceMidfield = 4
    CentralMidfield = 5
    LeftMidfield = 6
    # NOTE: the 'RIght' capitalisation typo is kept on purpose — other code
    # may reference this member by name.
    RIghtMidfield = 7
    AttackMidfield = 8
    CentralFront = 9
81
+
82
+
83
class GameMode(Enum):
    """Game-mode ids reported in obs['game_mode']; Normal is open play."""
    Normal = 0
    KickOff = 1
    GoalKick = 2
    FreeKick = 3
    Corner = 4
    ThrowIn = 5
    Penalty = 6
91
+
92
+
93
def human_readable_agent(agent: Callable[[Dict], Action]):
    """
    Decorator allowing for more human-friendly implementation of the agent function.
    @human_readable_agent
    def my_agent(obs):
        ...
        return football_action_set.action_right
    """
    @wraps(agent)
    def agent_wrapper(obs) -> List[int]:
        # Replace the raw sticky-action bit vector with a set of Action enums.
        active = set()
        for idx, flag in enumerate(obs['sticky_actions']):
            if flag:
                active.add(sticky_index_to_action[idx])
        obs['sticky_actions'] = active
        # Strongly type the game mode.
        obs['game_mode'] = GameMode(obs['game_mode'])
        # In single-agent mode 'designated' always duplicates 'active'; drop it.
        if 'designated' in obs:
            del obs['designated']
        # Convert raw role ids into PlayerRole enums for both teams.
        obs['left_team_roles'] = [PlayerRole(r) for r in obs['left_team_roles']]
        obs['right_team_roles'] = [PlayerRole(r) for r in obs['right_team_roles']]
        return [agent(obs).value]

    return agent_wrapper
120
+
121
def find_patterns(obs, player_x, player_y):
    """Return the memory-pattern list of the first matching group.

    Groups are tried in priority order; the last group always matches,
    so a list is always returned in practice.
    """
    for build_group in groups_of_memory_patterns:
        candidate = build_group(obs, player_x, player_y)
        if candidate["environment_fits"](obs, player_x, player_y):
            return candidate["get_memory_patterns"](obs, player_x, player_y)
127
+
128
+
129
def get_action_of_agent(obs, player_x, player_y):
    """Pick an action: first matching pattern of the first matching group wins."""
    candidates = find_patterns(obs, player_x, player_y)
    for build_pattern in candidates:
        pattern = build_pattern(obs, player_x, player_y)
        if pattern["environment_fits"](obs, player_x, player_y):
            return pattern["get_action"](obs, player_x, player_y)
137
+
138
+
139
def get_distance(x1, y1, right_team):
    """Euclidean distance from (x1, y1) to an opponent position.

    The y components are scaled by 2.38 so distances reflect the pitch's
    x:y coordinate aspect ratio ([-1, 1] x [-0.42, 0.42]).
    """
    dx = x1 - right_team[0]
    dy = y1 * 2.38 - right_team[1] * 2.38
    return math.sqrt(dx ** 2 + dy ** 2)
142
+
143
+
144
def run_to_ball_bottom(obs, player_x, player_y):
    """ chase the ball when it sits straight below the player """
    def environment_fits(obs, player_x, player_y):
        """ ball is below, on (almost) the same vertical line """
        ball_x, ball_y = obs["ball"][0], obs["ball"][1]
        return bool(ball_y > player_y and abs(ball_x - player_x) < 0.01)

    def get_action(obs, player_x, player_y):
        """ move straight down """
        return Action.Bottom

    return dict(environment_fits=environment_fits, get_action=get_action)
159
+
160
+
161
def run_to_ball_bottom_left(obs, player_x, player_y):
    """ chase the ball when it lies below and to the left of the player """
    def environment_fits(obs, player_x, player_y):
        """ ball is in the bottom-left quadrant relative to the player """
        ball_x, ball_y = obs["ball"][0], obs["ball"][1]
        return bool(ball_x < player_x and ball_y > player_y)

    def get_action(obs, player_x, player_y):
        """ move diagonally down-left """
        return Action.BottomLeft

    return dict(environment_fits=environment_fits, get_action=get_action)
176
+
177
+
178
def run_to_ball_bottom_right(obs, player_x, player_y):
    """ chase the ball when it lies below and to the right of the player """
    def environment_fits(obs, player_x, player_y):
        """ ball is in the bottom-right quadrant relative to the player """
        ball_x, ball_y = obs["ball"][0], obs["ball"][1]
        return bool(ball_x > player_x and ball_y > player_y)

    def get_action(obs, player_x, player_y):
        """ move diagonally down-right """
        return Action.BottomRight

    return dict(environment_fits=environment_fits, get_action=get_action)
193
+
194
+
195
def run_to_ball_left(obs, player_x, player_y):
    """ chase the ball when it sits straight to the player's left """
    def environment_fits(obs, player_x, player_y):
        """ ball is to the left, on (almost) the same horizontal line """
        ball_x, ball_y = obs["ball"][0], obs["ball"][1]
        return bool(ball_x < player_x and abs(ball_y - player_y) < 0.01)

    def get_action(obs, player_x, player_y):
        """ move straight left """
        return Action.Left

    return dict(environment_fits=environment_fits, get_action=get_action)
210
+
211
+
212
def run_to_ball_right(obs, player_x, player_y):
    """ chase the ball when it sits straight to the player's right """
    def environment_fits(obs, player_x, player_y):
        """ ball is to the right, on (almost) the same horizontal line """
        ball_x, ball_y = obs["ball"][0], obs["ball"][1]
        return bool(ball_x > player_x and abs(ball_y - player_y) < 0.01)

    def get_action(obs, player_x, player_y):
        """ move straight right """
        return Action.Right

    return dict(environment_fits=environment_fits, get_action=get_action)
227
+
228
+
229
def run_to_ball_top(obs, player_x, player_y):
    """ chase the ball when it sits straight above the player """
    def environment_fits(obs, player_x, player_y):
        """ ball is above, on (almost) the same vertical line """
        ball_x, ball_y = obs["ball"][0], obs["ball"][1]
        return bool(ball_y < player_y and abs(ball_x - player_x) < 0.01)

    def get_action(obs, player_x, player_y):
        """ move straight up """
        return Action.Top

    return dict(environment_fits=environment_fits, get_action=get_action)
244
+
245
+
246
def run_to_ball_top_left(obs, player_x, player_y):
    """ chase the ball when it lies above and to the left of the player """
    def environment_fits(obs, player_x, player_y):
        """ ball is in the top-left quadrant relative to the player """
        ball_x, ball_y = obs["ball"][0], obs["ball"][1]
        return bool(ball_x < player_x and ball_y < player_y)

    def get_action(obs, player_x, player_y):
        """ move diagonally up-left """
        return Action.TopLeft

    return dict(environment_fits=environment_fits, get_action=get_action)
261
+
262
+
263
def run_to_ball_top_right(obs, player_x, player_y):
    """ chase the ball when it lies above and to the right of the player """
    def environment_fits(obs, player_x, player_y):
        """ ball is in the top-right quadrant relative to the player """
        ball_x, ball_y = obs["ball"][0], obs["ball"][1]
        return bool(ball_x > player_x and ball_y < player_y)

    def get_action(obs, player_x, player_y):
        """ move diagonally up-right """
        return Action.TopRight

    return dict(environment_fits=environment_fits, get_action=get_action)
278
+
279
+
280
def idle(obs, player_x, player_y):
    """ universal fallback: always applicable, just stand still """
    def environment_fits(obs, player_x, player_y):
        """ matches every environment """
        return True

    def get_action(obs, player_x, player_y):
        """ emit Idle """
        return Action.Idle

    return dict(environment_fits=environment_fits, get_action=get_action)
291
+
292
+
293
def start_sprinting(obs, player_x, player_y):
    """ ensure the sprint sticky action is on (dropping dribble first) """
    def environment_fits(obs, player_x, player_y):
        """ matches while sprint is not already active """
        return bool(Action.Sprint not in obs["sticky_actions"])

    def get_action(obs, player_x, player_y):
        """ release dribble before sprinting, then start the sprint """
        if Action.Dribble in obs["sticky_actions"]:
            return Action.ReleaseDribble
        return Action.Sprint

    return dict(environment_fits=environment_fits, get_action=get_action)
308
+
309
+
310
def corner(obs, player_x, player_y):
    """ aim into the box, then deliver a high pass on corners """
    def environment_fits(obs, player_x, player_y):
        """ matches the Corner game mode """
        return bool(obs['game_mode'] == GameMode.Corner)

    def get_action(obs, player_x, player_y):
        """ face the goal area diagonally, then cross """
        aim = Action.TopRight if player_y > 0 else Action.BottomRight
        if aim not in obs["sticky_actions"]:
            return aim
        return Action.HighPass

    return dict(environment_fits=environment_fits, get_action=get_action)
330
+
331
+
332
def free_kick(obs, player_x, player_y):
    """ shoot near goal, otherwise short-pass on free kicks """
    def environment_fits(obs, player_x, player_y):
        """ matches the FreeKick game mode """
        return bool(obs['game_mode'] == GameMode.FreeKick)

    def get_action(obs, player_x, player_y):
        """ close to goal: aim diagonally and shoot; far: aim and short-pass """
        if player_x > 0.5:
            aim = Action.TopRight if player_y > 0 else Action.BottomRight
            if aim not in obs["sticky_actions"]:
                return aim
            return Action.Shot
        # far from goal: aim toward the centre of the pitch instead
        aim = Action.BottomRight if player_y > 0 else Action.TopRight
        if aim not in obs["sticky_actions"]:
            return aim
        return Action.ShortPass

    return dict(environment_fits=environment_fits, get_action=get_action)
363
+
364
+
365
def goal_kick(obs, player_x, player_y):
    """ short pass forward on goal kicks """
    def environment_fits(obs, player_x, player_y):
        """ matches the GoalKick game mode """
        return bool(obs['game_mode'] == GameMode.GoalKick)

    def get_action(obs, player_x, player_y):
        """ face bottom-right first, then play a short pass """
        if Action.BottomRight in obs["sticky_actions"]:
            return Action.ShortPass
        return Action.BottomRight

    return dict(environment_fits=environment_fits, get_action=get_action)
381
+
382
+
383
def kick_off(obs, player_x, player_y):
    """ short pass sideways at kick off """
    def environment_fits(obs, player_x, player_y):
        """ matches the KickOff game mode """
        return bool(obs['game_mode'] == GameMode.KickOff)

    def get_action(obs, player_x, player_y):
        """ face toward the middle of the pitch, then play a short pass """
        aim = Action.Top if player_y > 0 else Action.Bottom
        if aim not in obs["sticky_actions"]:
            return aim
        return Action.ShortPass

    return dict(environment_fits=environment_fits, get_action=get_action)
403
+
404
+
405
def penalty(obs, player_x, player_y):
    """ pick a corner at random, then shoot, on penalties """
    def environment_fits(obs, player_x, player_y):
        """ matches the Penalty game mode """
        return bool(obs['game_mode'] == GameMode.Penalty)

    def get_action(obs, player_x, player_y):
        """ 50/50 aim top- or bottom-right (if no corner chosen yet), then shoot """
        sticky = obs["sticky_actions"]
        # draw first so the RNG stream matches the original call order
        roll = random.random()
        if roll < 0.5 and Action.TopRight not in sticky and Action.BottomRight not in sticky:
            return Action.TopRight
        if Action.BottomRight not in sticky:
            return Action.BottomRight
        return Action.Shot

    return dict(environment_fits=environment_fits, get_action=get_action)
426
+
427
def throw_in(obs, player_x, player_y):
    """ short pass forward on throw-ins """
    def environment_fits(obs, player_x, player_y):
        """ matches the ThrowIn game mode """
        return bool(obs['game_mode'] == GameMode.ThrowIn)

    def get_action(obs, player_x, player_y):
        """ face right first, then play a short pass """
        if Action.Right in obs["sticky_actions"]:
            return Action.ShortPass
        return Action.Right

    return dict(environment_fits=environment_fits, get_action=get_action)
443
+
444
+
445
def defence_memory_patterns(obs, player_x, player_y):
    """ patterns for when our team does not own the ball: chase it """
    def environment_fits(obs, player_x, player_y):
        """ matches whenever the left team does not own the ball """
        return bool(obs["ball_owned_team"] != 0)

    def get_memory_patterns(obs, player_x, player_y):
        """ project the ball forward, then chase it in the right direction.

        NOTE: mutates obs["ball"] in place to lead the ball by its velocity —
        downstream patterns in this tick see the shifted position.
        """
        ball = obs["ball"]
        ball[0] += obs["ball_direction"][0] * 7
        ball[1] += obs["ball_direction"][1] * 3
        # when an opponent carries the ball away from the y centre, aim
        # slightly goal-side and in-field of it
        if abs(ball[1]) > 0.07 and obs["ball_owned_team"] == 1:
            ball[0] -= 0.01
            ball[1] += -0.01 if ball[1] > 0 else 0.01

        return [
            start_sprinting,
            run_to_ball_right,
            run_to_ball_left,
            run_to_ball_bottom,
            run_to_ball_top,
            run_to_ball_top_right,
            run_to_ball_top_left,
            run_to_ball_bottom_right,
            run_to_ball_bottom_left,
            idle,
        ]

    return dict(environment_fits=environment_fits, get_memory_patterns=get_memory_patterns)
482
+
483
def goalkeeper_memory_patterns(obs, player_x, player_y):
    """ patterns for when the controlled player is the keeper holding the ball """
    def environment_fits(obs, player_x, player_y):
        """ active player is the goalkeeper (index 0) and owns the ball """
        return bool(obs["ball_owned_player"] == obs["active"]
                    and obs["ball_owned_team"] == 0
                    and obs["ball_owned_player"] == 0)

    def get_memory_patterns(obs, player_x, player_y):
        """ clear the ball upfield, otherwise idle """
        return [long_pass_forward, idle]

    return dict(environment_fits=environment_fits, get_memory_patterns=get_memory_patterns)
503
+
504
+
505
def offence_memory_patterns(obs, player_x, player_y):
    """ patterns for when the controlled player carries the ball """
    def environment_fits(obs, player_x, player_y):
        """ our team owns the ball and the active player is the carrier """
        return bool(obs["ball_owned_team"] == 0
                    and obs["ball_owned_player"] == obs["active"])

    def get_memory_patterns(obs, player_x, player_y):
        """ attacking options, most specific first """
        return [
            close_to_goalkeeper_shot,
            spot_shot,
            cross,
            long_pass_forward,
            keep_the_ball,
            idle,
        ]

    return dict(environment_fits=environment_fits, get_memory_patterns=get_memory_patterns)
527
+
528
+
529
def other_memory_patterns(obs, player_x, player_y):
    """ catch-all group: always matches, only idles """
    def environment_fits(obs, player_x, player_y):
        """ matches every environment """
        return True

    def get_memory_patterns(obs, player_x, player_y):
        """ idle is the only option """
        return [idle]

    return dict(environment_fits=environment_fits, get_memory_patterns=get_memory_patterns)
543
+
544
def special_game_modes_memory_patterns(obs, player_x, player_y):
    """ set-piece patterns for any non-Normal game mode """
    def environment_fits(obs, player_x, player_y):
        """ matches whenever the game mode is not open play """
        return bool(obs['game_mode'] != GameMode.Normal)

    def get_memory_patterns(obs, player_x, player_y):
        """ one handler per set-piece kind, idle as a safety net """
        return [
            corner,
            free_kick,
            goal_kick,
            kick_off,
            penalty,
            throw_in,
            idle,
        ]

    return dict(environment_fits=environment_fits, get_memory_patterns=get_memory_patterns)
567
+
568
+
569
def special_spot_shot(obs, player_x, player_y):
    """ shoot from a central position deep in the opponent half """
    def environment_fits(obs, player_x, player_y):
        """ player is past x=0.8 and within the central channel """
        return bool(player_x > 0.8 and abs(player_y) < 0.21)

    def get_memory_patterns(obs, player_x, player_y):
        """ shot first, idle as a fallback.

        NOTE(review): `shot` is not defined in this part of the file —
        presumably defined further down; confirm before enabling this group.
        """
        return [shot, idle]

    return dict(environment_fits=environment_fits, get_memory_patterns=get_memory_patterns)
587
+
588
+
589
def own_goal(obs, player_x, player_y):
    """ (disabled in groups list) pattern group near the own goal line """
    def environment_fits(obs, player_x, player_y):
        """ player is nearly on the own goal line, off the exact y centre.

        NOTE(review): `and player_y` is a truthiness test (any non-zero y);
        it looks like a range check was intended — confirm before enabling.
        """
        return bool(player_x < -0.9 and player_y)

    def get_memory_patterns(obs, player_x, player_y):
        """ delegate to own_goal_2 (defined elsewhere in the file) """
        return [own_goal_2]

    return dict(environment_fits=environment_fits, get_memory_patterns=get_memory_patterns)
606
+
607
def get_best_direction(obs, target_direction):
    """Return the movement Action best aligned with the path to a target.

    Scores every unit direction vector by its dot product with the vector
    from the active player to target_direction and picks the maximum.
    """
    me = obs["left_team"][obs["active"]]
    to_target = np.array(target_direction) - me
    scores = []
    for vec in ALL_DIRECTION_VECS:
        unit = np.array(vec) / np.linalg.norm(np.array(vec))
        scores.append(np.dot(to_target, unit))
    return ALL_DIRECTION_ACTIONS[int(np.argmax(scores))]
613
+
614
def get_distance2ball(obs):
    """Euclidean distance from the active left-team player to the ball (xy only)."""
    delta = obs["ball"][:2] - obs["left_team"][obs["active"]]
    return np.linalg.norm(delta)
616
+
617
def get_target2line(obs):
    """Keeper positioning target on the shot line.

    Returns the point 0.03 away from the own-goal centre (-1, 0) along the
    ray toward the ball, so the keeper covers the shooting angle.
    """
    active_position = obs["left_team"][obs["active"]]
    ball_x, ball_y = obs['ball'][0], obs['ball'][1]
    dx = ball_x + 1
    # +1e-5 guards against division by zero when the ball sits on the goal centre
    dist = (dx ** 2 + ball_y ** 2) ** 0.5 + 1e-5
    return np.array([0.03 * (dx / dist) - 1, 0.03 * (ball_y / dist)])
625
+
626
def already_near_goal(obs, player_x, player_y):
    """ keeper is already home: stop moving and drop all sticky actions """
    def environment_fits(obs, player_x, player_y):
        """ active player is within 0.02 of the keeper's home spot """
        me = obs["left_team"][obs["active"]]
        gap = np.linalg.norm(np.array([-1 + GOAL_BIAS, 0]) - me)
        return bool(gap < 0.02)

    def get_action(obs, player_x, player_y):
        """ release sprint, then dribble, then any direction; finally idle """
        sticky = obs["sticky_actions"]
        if Action.Sprint in sticky:
            return Action.ReleaseSprint
        if Action.Dribble in sticky:
            return Action.ReleaseDribble
        if sticky:
            return Action.ReleaseDirection
        return Action.Idle

    return dict(environment_fits=environment_fits, get_action=get_action)
649
+
650
def already_in_line(obs, player_x, player_y):
    """ keeper is already on the shot line: stop and drop sticky actions """
    def environment_fits(obs, player_x, player_y):
        """ active player is within 0.02 of the shot-line target """
        target = get_target2line(obs)
        gap = np.linalg.norm(target - obs['left_team'][obs['active']])
        return bool(gap < 0.02)

    def get_action(obs, player_x, player_y):
        """ release sprint, then dribble, then any direction; finally idle """
        sticky = obs["sticky_actions"]
        if Action.Sprint in sticky:
            return Action.ReleaseSprint
        if Action.Dribble in sticky:
            return Action.ReleaseDribble
        if sticky:
            return Action.ReleaseDirection
        return Action.Idle

    return dict(environment_fits=environment_fits, get_action=get_action)
673
+
674
def run_to_goal(obs, player_x, player_y):
    """ always-applicable pattern steering the keeper back toward home """
    def environment_fits(obs, player_x, player_y):
        """ matches every environment """
        return True

    def get_action(obs, player_x, player_y):
        """ head for the spot just in front of the own goal centre """
        return get_best_direction(obs, [-1 + GOAL_BIAS, 0])

    return dict(environment_fits=environment_fits, get_action=get_action)
688
+
689
def run_to_line(obs, player_x, player_y):
    """ always-applicable pattern steering the keeper onto the shot line """
    def environment_fits(obs, player_x, player_y):
        """ matches every environment """
        return True

    def get_action(obs, player_x, player_y):
        """ head for the shot-line target computed from the ball position """
        return get_best_direction(obs, get_target2line(obs))

    return dict(environment_fits=environment_fits, get_action=get_action)
699
+
700
def goal_keeper_far_pattern(obs, player_x, player_y):
    """ safety group: send the keeper home when he need not play the ball """
    def environment_fits(obs, player_x, player_y):
        """ we control the keeper and the ball is far / covered by a teammate """
        if obs["active"] != 0:
            return False
        me = obs["left_team"][0]
        dist2ball = np.linalg.norm(obs["ball"][:2] - me)
        # ball is far away
        if dist2ball > 0.5:
            return True
        # a teammate already carries the ball
        if obs['ball_owned_team'] == 0 and obs['ball_owned_player'] != 0:
            return True
        # keeper is out of position and a teammate is closer to the ball
        if me[0] > -0.7 or abs(me[1]) > 0.25:
            for teammate_pos in obs['left_team'][1:]:
                if np.linalg.norm(obs["ball"][:2] - teammate_pos) < dist2ball:
                    return True
        return False

    def get_memory_patterns(obs, player_x, player_y):
        """ stop if already home, otherwise sprint back to goal """
        return [
            already_near_goal,
            start_sprinting,
            run_to_goal,
        ]

    return dict(environment_fits=environment_fits, get_memory_patterns=get_memory_patterns)
727
+
728
def ball_distance_2_5(obs, player_x, player_y):
    """ keeper positioning when the ball is at mid range (0.2 - 0.5) """
    def environment_fits(obs, player_x, player_y):
        """ we control the keeper, opponents/nobody own the ball, ball mid-range.

        NOTE(review): the [0.2, 0.5] band overlaps ball_distance_close's
        < 0.25 band; group order in groups_of_memory_patterns breaks the tie.
        """
        if obs["active"] != 0 or obs['ball_owned_team'] == 0:
            return False
        distance2ball = get_distance2ball(obs)
        return bool(0.2 <= distance2ball <= 0.5)

    def get_memory_patterns(obs, player_x, player_y):
        """ hold the shot line, sprinting into position if needed """
        return [
            already_in_line,
            start_sprinting,
            run_to_line,
        ]

    return dict(environment_fits=environment_fits, get_memory_patterns=get_memory_patterns)
748
+
749
def ball_distance_close(obs, player_x, player_y):
    """ keeper reaction when the ball is very close """
    def environment_fits(obs, player_x, player_y):
        """ we control the keeper, opponents/nobody own the ball, ball close """
        if obs["active"] != 0 or obs['ball_owned_team'] == 0:
            return False
        return bool(get_distance2ball(obs) < 0.25)

    def get_memory_patterns(obs, player_x, player_y):
        """ clear the danger.

        NOTE(review): `shot` is not defined in this part of the file —
        presumably defined further down; confirm it exists.
        """
        return [shot]

    return dict(environment_fits=environment_fits, get_memory_patterns=get_memory_patterns)
767
+
768
# Ordered groups of memory patterns; find_patterns picks the FIRST group whose
# environment_fits returns True, so this list encodes priority.
groups_of_memory_patterns = [
    goal_keeper_far_pattern,  # safety: retreat to goal when the keeper need not act
    goalkeeper_memory_patterns,  # goalkeeper has the ball
    # special_spot_shot,  # shot (disabled by the author: branch never reached)
    special_game_modes_memory_patterns,  # special game modes (set pieces)
    ball_distance_2_5,  # ball at mid range: hold the shot line
    ball_distance_close,  # ball very close: clear it
    # own_goal,
    # offence_memory_patterns,  # our team has the ball (disabled by the author: branch never reached)
    defence_memory_patterns,  # chase the ball
    other_memory_patterns  # idle fallback
]
781
+
782
+
783
def keep_the_ball(obs, player_x, player_y):
    """Ball-carrier pattern: advance right, dodging opponents directly ahead
    by cutting diagonally, preferring a short pass when a teammate supports
    the diagonal.

    Fixes vs. original: when player_y == 0 with an opponent ahead the
    original fell through every branch and returned None (crashing later on
    `.value`); now it falls back to Action.Right. Also drops the unused
    `dist`/`closest`/`back`/`bottom_left`/`top_left` computations.
    """
    def environment_fits(obs, player_x, player_y):
        """Always applicable (used as a late fallback in the offence group)."""
        return True

    def get_action(obs, player_x, player_y):
        right_team, left_team = obs['right_team'], obs['left_team']
        # opponents directly ahead in a narrow corridor
        near = [i for i in right_team
                if player_x < i[0] < player_x + 0.2
                and player_y - 0.05 < i[1] < player_y + 0.05]
        # teammates available for a diagonal short pass, below / above
        bottom_right = [i for i in left_team
                        if player_x - 0.05 < i[0] < player_x + 0.2
                        and player_y < i[1] < player_y + 0.2]
        top_right = [i for i in left_team
                     if player_x - 0.05 < i[0] < player_x + 0.2
                     and player_y - 0.2 < i[1] < player_y]

        # nobody ahead: just run forward
        if not near:
            return Action.Right

        if player_y > 0:
            if player_y > 0.35:  # close to the sideline: keep going forward
                return Action.Right
            if bottom_right:
                if Action.BottomRight not in obs['sticky_actions']:
                    return Action.BottomRight
                return Action.ShortPass
            return Action.BottomRight

        if player_y < 0:
            if player_y < -0.35:  # close to the sideline: keep going forward
                return Action.Right
            if top_right:
                if Action.TopRight not in obs['sticky_actions']:
                    return Action.TopRight
                return Action.ShortPass
            return Action.TopRight

        # player_y == 0: previously unhandled (returned None) — go forward
        return Action.Right

    return {'environment_fits': environment_fits, 'get_action': get_action}
826
+
827
+
828
def spot_shot(obs, player_x, player_y):
    """ shoot from the central channel past x=0.75 """
    def environment_fits(obs, player_x, player_y):
        """ player is in shooting range of the goal """
        return bool(player_x > 0.75 and abs(player_y) < 0.21)

    def get_action(obs, player_x, player_y):
        """ aim for the far corner first, then shoot """
        aim = Action.TopRight if player_y >= 0 else Action.BottomRight
        if aim not in obs["sticky_actions"]:
            return aim
        return Action.Shot

    return dict(environment_fits=environment_fits, get_action=get_action)
849
+
850
+
851
def cross(obs, player_x, player_y):
    """Deliver a high pass (cross) from wide positions near the byline."""

    def environment_fits(obs, player_x, player_y):
        """True when the player is deep (x > 0.7) and wide (|y| > 0.21)."""
        return player_x > 0.7 and abs(player_y) > 0.21

    def get_action(obs, player_x, player_y):
        """Turn infield, then deliver a high pass."""
        if player_x > 0.88:
            infield = Action.Top if player_y > 0 else Action.Bottom
            if infield not in obs["sticky_actions"]:
                return infield
            return Action.HighPass

        # NOTE(review): this branch is unreachable — any x > 0.9 also
        # satisfies x > 0.88 and returns above. Kept byte-equivalent to
        # preserve behavior; reordering would change the agent.
        if player_x > 0.9:
            if (Action.Right in obs["sticky_actions"]
                    or Action.TopRight in obs["sticky_actions"]
                    or Action.BottomRight in obs["sticky_actions"]):
                return Action.ReleaseDirection
            if Action.Right not in obs["sticky_actions"]:
                if player_y > 0 and Action.Top not in obs["sticky_actions"]:
                    return Action.Top
                if player_y < 0 and Action.Bottom not in obs["sticky_actions"]:
                    return Action.Bottom
            return Action.HighPass
        # Implicitly returns None when 0.7 < x <= 0.88 (original behavior).

    return {"environment_fits": environment_fits, "get_action": get_action}
883
+
884
+
885
def close_to_goalkeeper_shot(obs, player_x, player_y):
    """Shoot when the player is within striking distance of the keeper."""

    def environment_fits(obs, player_x, player_y):
        """True when the extrapolated keeper position is near the player."""
        # Project the keeper 13 steps ahead along its current velocity.
        keeper = [
            obs["right_team"][0][0] + obs["right_team_direction"][0][0] * 13,
            obs["right_team"][0][1] + obs["right_team_direction"][0][1] * 13,
        ]
        return get_distance(player_x, player_y, keeper) < 0.25

    def get_action(obs, player_x, player_y):
        """Turn toward the far corner first, then shoot."""
        corner = Action.TopRight if player_y >= 0 else Action.BottomRight
        if corner not in obs["sticky_actions"]:
            return corner
        return Action.Shot

    return {"environment_fits": environment_fits, "get_action": get_action}
908
+
909
+
910
def long_pass_forward(obs, player_x, player_y):
    """Move the ball upfield with a high pass when deep in our own half.

    Fires when the player is far from the opponent goal (x < -0.4).  The
    chosen action depends on how close the nearest opponent is: with
    space, drift toward the centre; under moderate pressure, run
    diagonally forward; under tight pressure (or out wide), straighten up
    and launch a high pass.

    Fixes over the original: the boundary values ``min_dist == 0.4`` and
    ``min_dist == 0.2`` previously matched no branch and implicitly
    returned ``None``; the distance minimum is now computed once; unused
    locals (``closest``, ``left_team``) are removed.
    """

    def environment_fits(obs, player_x, player_y):
        """True when the player is far from the opponent's goal."""
        return player_x < -0.4

    def get_action(obs, player_x, player_y):
        """Pick a direction or a high pass based on opponent pressure."""
        dist = [get_distance(player_x, player_y, opp) for opp in obs["right_team"]]
        min_dist = np.min(dist)  # hoisted: computed once instead of per-branch

        # Out wide: straighten up and launch the pass.
        if abs(player_y) > 0.22:
            if Action.Right not in obs["sticky_actions"]:
                return Action.Right
            return Action.HighPass

        if min_dist > 0.4:
            # Plenty of space: drift toward the centre line.
            return Action.Bottom if player_y > 0 else Action.Top
        if min_dist > 0.2:
            # Moderate pressure: run diagonally forward.
            return Action.TopRight if player_y < 0 else Action.BottomRight
        # Tight pressure (min_dist <= 0.2): clear it long.
        if Action.Right not in obs["sticky_actions"]:
            return Action.Right
        return Action.HighPass

    return {"environment_fits": environment_fits, "get_action": get_action}
950
+
951
def shot(obs, player_x, player_y):
    """Fallback pattern that always fits and always shoots."""

    def environment_fits(obs, player_x, player_y):
        """This pattern matches any game state."""
        return True

    def get_action(obs, player_x, player_y):
        """Always shoot."""
        return Action.Shot

    return {"environment_fits": environment_fits, "get_action": get_action}
965
+
966
+
967
def own_goal_2(obs, player_x, player_y):
    """Catch-all pattern: fits every state and shoots."""

    def environment_fits(obs, player_x, player_y):
        """Matches every game state."""
        return True

    def get_action(obs, player_x, player_y):
        """Shoot regardless of position."""
        return Action.Shot

    return {"environment_fits": environment_fits, "get_action": get_action}
975
+
976
+
977
+ # @human_readable_agent wrapper modifies raw observations
978
+ # provided by the environment:
979
+ # https://github.com/google-research/football/blob/master/gfootball/doc/observation.md#raw-observations
980
+ # into a form easier to work with by humans.
981
+ # Following modifications are applied:
982
+ # - Action, PlayerRole and GameMode enums are introduced.
983
+ # - 'sticky_actions' are turned into a set of active actions (Action enum)
984
+ # see usage example below.
985
+ # - 'game_mode' is turned into GameMode enum.
986
+ # - 'designated' field is removed, as it always equals to 'active'
987
+ # when a single player is controlled on the team.
988
+ # - 'left_team_roles'/'right_team_roles' are turned into PlayerRole enums.
989
+ # - Action enum is to be returned by the agent function.
990
@human_readable_agent
def agent_get_action(obs):
    """Entry point: pick an action for the controlled left-team player."""
    # Scratch space shared by the memory-pattern helpers.
    obs["memory_patterns"] = {}
    # The environment mirrors observations/actions, so we always control
    # the left team.
    player_x, player_y = obs["left_team"][obs["active"]]
    # Dispatch to whichever memory pattern fits the current situation.
    return get_action_of_agent(obs, player_x, player_y)
openrl_policy.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright 2023 The OpenRL Authors.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # https://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import numpy as np
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ from torch.distributions import Categorical
22
+
23
+ import gym
24
+
25
def check(input):
    """Convert *input* to a ``torch.Tensor`` if it is a numpy array.

    Any other value (including an existing tensor) is returned unchanged.
    The parameter name shadows the ``input`` builtin; it is kept for
    backward compatibility with existing keyword callers.
    """
    # isinstance replaces the original ``type(input) == np.ndarray``:
    # it handles ndarray subclasses and is the idiomatic type test.
    return torch.from_numpy(input) if isinstance(input, np.ndarray) else input
+
29
class FcEncoder(nn.Module):
    """Stack of ``fc_num`` Linear -> ReLU -> LayerNorm blocks."""

    def __init__(self, fc_num, input_size, output_size):
        super(FcEncoder, self).__init__()
        # First block maps input_size -> output_size.
        self.first_mlp = nn.Sequential(
            nn.Linear(input_size, output_size), nn.ReLU(), nn.LayerNorm(output_size)
        )
        # Remaining fc_num - 1 blocks keep the width at output_size.
        # Module layout is kept identical to the original so that
        # checkpoint state-dict keys still match.
        self.mlp = nn.Sequential()
        for _ in range(fc_num - 1):
            block = nn.Sequential(
                nn.Linear(output_size, output_size), nn.ReLU(), nn.LayerNorm(output_size)
            )
            self.mlp.append(block)

    def forward(self, x):
        hidden = self.first_mlp(x)
        return self.mlp(hidden)
44
+
45
def init(module, weight_init, bias_init, gain=1):
    """Apply *weight_init*/*bias_init* to *module* in place and return it."""
    weight_init(module.weight.data, gain=gain)
    bias = module.bias
    if bias is not None:
        bias_init(bias.data)
    return module
50
+
51
+
52
class FixedCategorical(torch.distributions.Categorical):
    """Categorical distribution whose outputs keep a trailing unit dim."""

    def sample(self):
        """Sample with shape (..., 1) instead of (...)."""
        return super().sample().unsqueeze(-1)

    def log_probs(self, actions):
        """Summed log-probabilities with shape (batch, 1)."""
        flat = actions.squeeze(-1)
        lp = super().log_prob(flat)
        return lp.view(actions.size(0), -1).sum(-1).unsqueeze(-1)

    def mode(self):
        """Most likely action index, shape (..., 1)."""
        return self.probs.argmax(dim=-1, keepdim=True)
67
+
68
class Categorical(nn.Module):
    """Linear action head that emits a ``FixedCategorical`` over logits.

    NOTE: this class intentionally shadows the
    ``torch.distributions.Categorical`` imported at the top of the file.
    """

    def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01):
        super(Categorical, self).__init__()
        # use_orthogonal selects orthogonal init, otherwise Xavier.
        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]

        def init_(m):
            return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain)

        self.linear = init_(nn.Linear(num_inputs, num_outputs))

    def forward(self, x, available_actions=None):
        logits = self.linear(x)
        if available_actions is not None:
            # Mask unavailable actions with a very negative logit so their
            # probability is effectively zero.
            logits[available_actions == 0] = -1e10
        return FixedCategorical(logits=logits)
82
+
83
+
84
class AddBias(nn.Module):
    """Adds a learned bias vector, broadcast over 2-D or 4-D inputs."""

    def __init__(self, bias):
        super(AddBias, self).__init__()
        self._bias = nn.Parameter(bias.unsqueeze(1))

    def forward(self, x):
        if x.dim() == 2:
            shaped = self._bias.t().view(1, -1)
        else:
            # Non-2-D input gets a (1, C, 1, 1) broadcast shape —
            # presumably NCHW feature maps; TODO confirm with callers.
            shaped = self._bias.t().view(1, -1, 1, 1)
        return x + shaped
96
+
97
class ACTLayer(nn.Module):
    """Action head wrapping a single discrete ``Categorical`` output.

    Exposes three APIs: ``forward`` (sample/argmax an action),
    ``get_probs`` (full action distribution), and ``evaluate_actions``
    (log-probs and entropy of given actions, for PPO-style updates).

    The mixed / multidiscrete / continuous flags are hard-wired to False
    in ``__init__``, so only the final ``else`` (single discrete) branch
    of each method can execute here.  NOTE(review): the dead branches
    reference ``self.action_outs``, which is never defined in this class
    — they would raise AttributeError if ever enabled.
    """

    def __init__(self, action_space, inputs_dim, use_orthogonal, gain):
        super(ACTLayer, self).__init__()
        # Always False in this file; kept from the original multi-space
        # implementation.
        self.multidiscrete_action = False
        self.continuous_action = False
        self.mixed_action = False

        action_dim = action_space.n
        self.action_out = Categorical(inputs_dim, action_dim, use_orthogonal, gain)

    def forward(self, x, available_actions=None, deterministic=False):
        """Return (actions, action_log_probs) sampled from the head.

        With ``deterministic=True`` the distribution mode (argmax) is
        used instead of sampling.
        """
        if self.mixed_action :
            # Dead branch: self.mixed_action is always False (see __init__).
            actions = []
            action_log_probs = []
            for action_out in self.action_outs:
                action_logit = action_out(x)
                action = action_logit.mode() if deterministic else action_logit.sample()
                action_log_prob = action_logit.log_probs(action)
                actions.append(action.float())
                action_log_probs.append(action_log_prob)

            actions = torch.cat(actions, -1)
            action_log_probs = torch.sum(torch.cat(action_log_probs, -1), -1, keepdim=True)

        elif self.multidiscrete_action:
            # Dead branch: self.multidiscrete_action is always False.
            actions = []
            action_log_probs = []
            for action_out in self.action_outs:
                action_logit = action_out(x)
                action = action_logit.mode() if deterministic else action_logit.sample()
                action_log_prob = action_logit.log_probs(action)
                actions.append(action)
                action_log_probs.append(action_log_prob)

            actions = torch.cat(actions, -1)
            action_log_probs = torch.cat(action_log_probs, -1)

        elif self.continuous_action:
            # Dead branch: self.continuous_action is always False.
            action_logits = self.action_out(x)
            actions = action_logits.mode() if deterministic else action_logits.sample()
            action_log_probs = action_logits.log_probs(actions)

        else:
            # Active path: single discrete action with optional availability mask.
            action_logits = self.action_out(x, available_actions)
            actions = action_logits.mode() if deterministic else action_logits.sample()
            action_log_probs = action_logits.log_probs(actions)

        return actions, action_log_probs

    def get_probs(self, x, available_actions=None):
        """Return the per-action probability tensor for features *x*."""
        if self.mixed_action or self.multidiscrete_action:
            # Dead branch (flags are always False).
            action_probs = []
            for action_out in self.action_outs:
                action_logit = action_out(x)
                action_prob = action_logit.probs
                action_probs.append(action_prob)
            action_probs = torch.cat(action_probs, -1)
        elif self.continuous_action:
            # Dead branch.
            action_logits = self.action_out(x)
            action_probs = action_logits.probs
        else:
            # Active path.
            action_logits = self.action_out(x, available_actions)
            action_probs = action_logits.probs

        return action_probs

    def evaluate_actions(self, x, action, available_actions=None, active_masks=None, get_probs=False):
        """Return (log_probs, entropy[, distribution]) for given actions.

        ``active_masks`` (when provided) weights the entropy so that only
        active agents contribute.  With ``get_probs=True`` the fitted
        distribution object is returned as a third value.
        """
        if self.mixed_action:
            # Dead branch: self.mixed_action is always False.
            a, b = action.split((2, 1), -1)
            b = b.long()
            action = [a, b]
            action_log_probs = []
            dist_entropy = []
            for action_out, act in zip(self.action_outs, action):
                action_logit = action_out(x)
                action_log_probs.append(action_logit.log_probs(act))
                if active_masks is not None:
                    if len(action_logit.entropy().shape) == len(active_masks.shape):
                        dist_entropy.append((action_logit.entropy() * active_masks).sum()/active_masks.sum())
                    else:
                        dist_entropy.append((action_logit.entropy() * active_masks.squeeze(-1)).sum()/active_masks.sum())
                else:
                    dist_entropy.append(action_logit.entropy().mean())

            action_log_probs = torch.sum(torch.cat(action_log_probs, -1), -1, keepdim=True)
            # Fixed per-component entropy weights from the original code.
            dist_entropy = dist_entropy[0] * 0.0025 + dist_entropy[1] * 0.01

        elif self.multidiscrete_action:
            # Dead branch: self.multidiscrete_action is always False.
            action = torch.transpose(action, 0, 1)
            action_log_probs = []
            dist_entropy = []
            for action_out, act in zip(self.action_outs, action):
                action_logit = action_out(x)
                action_log_probs.append(action_logit.log_probs(act))
                if active_masks is not None:
                    dist_entropy.append((action_logit.entropy()*active_masks.squeeze(-1)).sum()/active_masks.sum())
                else:
                    dist_entropy.append(action_logit.entropy().mean())

            action_log_probs = torch.cat(action_log_probs, -1) # ! could be wrong
            dist_entropy = torch.tensor(dist_entropy).mean()

        elif self.continuous_action:
            # Dead branch: self.continuous_action is always False.
            action_logits = self.action_out(x)
            action_log_probs = action_logits.log_probs(action)
            act_entropy = action_logits.entropy()
            if active_masks is not None:
                dist_entropy = (act_entropy*active_masks).sum()/active_masks.sum()
            else:
                dist_entropy = act_entropy.mean()

        else:
            # Active path: single discrete action.
            action_logits = self.action_out(x, available_actions)
            action_log_probs = action_logits.log_probs(action)
            if active_masks is not None:
                dist_entropy = (action_logits.entropy()*active_masks.squeeze(-1)).sum()/active_masks.sum()
            else:
                dist_entropy = action_logits.entropy().mean()
        if not get_probs:
            return action_log_probs, dist_entropy
        else:
            return action_log_probs, dist_entropy, action_logits
222
+
223
class RNNLayer(nn.Module):
    """GRU/LSTM wrapper with episode-reset masking and output LayerNorm.

    For ``rnn_type='lstm'`` the hidden and cell states are packed
    together along the last dimension of a single tensor and split /
    re-concatenated inside ``rnn_forward``, so callers can treat both
    RNN types uniformly as a single hidden-state tensor.
    """

    def __init__(self, inputs_dim, outputs_dim, recurrent_N, use_orthogonal,rnn_type='gru'):
        super(RNNLayer, self).__init__()
        self._recurrent_N = recurrent_N
        self._use_orthogonal = use_orthogonal
        self.rnn_type = rnn_type
        if rnn_type == 'gru':
            self.rnn = nn.GRU(inputs_dim, outputs_dim, num_layers=self._recurrent_N)
        elif rnn_type == 'lstm':
            self.rnn = nn.LSTM(inputs_dim, outputs_dim, num_layers=self._recurrent_N)
        else:
            raise NotImplementedError(f'RNN type {rnn_type} has not been implemented.')

        # Orthogonal (or Xavier) init for weights, zeros for biases.
        for name, param in self.rnn.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0)
            elif 'weight' in name:
                if self._use_orthogonal:
                    nn.init.orthogonal_(param)
                else:
                    nn.init.xavier_uniform_(param)
        self.norm = nn.LayerNorm(outputs_dim)

    def rnn_forward(self, x, h):
        """Run the RNN; for LSTM, unpack/pack (h, c) from one tensor."""
        if self.rnn_type == 'lstm':
            # Last dim of h holds [hidden | cell] halves.
            h = torch.split(h, h.shape[-1] // 2, dim=-1)
            h = (h[0].contiguous(), h[1].contiguous())
        x_, h_ = self.rnn(x, h)
        if self.rnn_type == 'lstm':
            h_ = torch.cat(h_, -1)
        return x_, h_

    def forward(self, x, hxs, masks):
        """Run the RNN over *x* with hidden states *hxs*, resetting on masks.

        Two call shapes are supported: a single step per env (first
        branch, x and hxs have equal batch size), or a flattened
        (T * N, -1) rollout (second branch), which is chunked at the
        timesteps where any mask is zero so hidden state resets are
        applied between chunks.
        """
        if x.size(0) == hxs.size(0):
            x, hxs = self.rnn_forward(x.unsqueeze(0), (hxs * masks.repeat(1, self._recurrent_N).unsqueeze(-1)).transpose(0, 1).contiguous())
            #x= self.gru(x.unsqueeze(0))
            x = x.squeeze(0)
            hxs = hxs.transpose(0, 1)
        else:
            # x is a (T, N, -1) tensor that has been flatten to (T * N, -1)
            N = hxs.size(0)
            T = int(x.size(0) / N)

            # unflatten
            x = x.view(T, N, x.size(1))

            # Same deal with masks
            masks = masks.view(T, N)

            # Let's figure out which steps in the sequence have a zero for any agent
            # We will always assume t=0 has a zero in it as that makes the logic cleaner
            has_zeros = ((masks[1:] == 0.0)
                            .any(dim=-1)
                            .nonzero()
                            .squeeze()
                            .cpu())

            # +1 to correct the masks[1:]
            if has_zeros.dim() == 0:
                # Deal with scalar
                has_zeros = [has_zeros.item() + 1]
            else:
                has_zeros = (has_zeros + 1).numpy().tolist()

            # add t=0 and t=T to the list
            has_zeros = [0] + has_zeros + [T]

            hxs = hxs.transpose(0, 1)

            outputs = []
            for i in range(len(has_zeros) - 1):
                # We can now process steps that don't have any zeros in masks together!
                # This is much faster
                start_idx = has_zeros[i]
                end_idx = has_zeros[i + 1]
                # Zero out hidden states where the chunk starts after a reset.
                temp = (hxs * masks[start_idx].view(1, -1, 1).repeat(self._recurrent_N, 1, 1)).contiguous()
                rnn_scores, hxs = self.rnn_forward(x[start_idx:end_idx], temp)
                outputs.append(rnn_scores)

            # assert len(outputs) == T
            # x is a (T, N, -1) tensor
            x = torch.cat(outputs, dim=0)

            # flatten
            x = x.reshape(T * N, -1)
            hxs = hxs.transpose(0, 1)

        x = self.norm(x)
        return x, hxs
312
+
313
+
314
class InputEncoder(nn.Module):
    """Splits the flat observation into segments and encodes each one.

    Segment order and widths: active player (87), ball owner (57),
    left team (88), right team (88), match state (9).  The first four
    are encoded to 64 features each; match state keeps its width.
    """

    def __init__(self):
        super(InputEncoder, self).__init__()
        fc_layer_num = 2
        fc_output_num = 64
        # Widths of the flat observation segments, in order.
        self.active_input_num = 87
        self.ball_owner_input_num = 57
        self.left_input_num = 88
        self.right_input_num = 88
        self.match_state_input_num = 9

        self.active_encoder = FcEncoder(fc_layer_num, self.active_input_num, fc_output_num)
        self.ball_owner_encoder = FcEncoder(fc_layer_num, self.ball_owner_input_num, fc_output_num)
        self.left_encoder = FcEncoder(fc_layer_num, self.left_input_num, fc_output_num)
        self.right_encoder = FcEncoder(fc_layer_num, self.right_input_num, fc_output_num)
        self.match_state_encoder = FcEncoder(fc_layer_num, self.match_state_input_num, self.match_state_input_num)

    def forward(self, x):
        """Slice *x* column-wise, encode each segment, and concatenate."""
        widths = [
            self.active_input_num,
            self.ball_owner_input_num,
            self.left_input_num,
            self.right_input_num,
        ]
        segments = []
        offset = 0
        for width in widths:
            segments.append(x[:, offset:offset + width])
            offset += width
        segments.append(x[:, offset:])  # match-state tail

        encoders = [
            self.active_encoder,
            self.ball_owner_encoder,
            self.left_encoder,
            self.right_encoder,
            self.match_state_encoder,
        ]
        outputs = [enc(seg) for enc, seg in zip(encoders, segments)]
        return torch.cat(outputs, 1)
352
+
353
def get_fc(input_size, output_size):
    """Single Linear -> ReLU -> LayerNorm block."""
    layers = [nn.Linear(input_size, output_size), nn.ReLU(), nn.LayerNorm(output_size)]
    return nn.Sequential(*layers)
355
+
356
class ObsEncoder(nn.Module):
    """Observation pipeline: input encoder -> embedding -> RNN -> MLP."""

    def __init__(self, input_embedding_size, hidden_size, _recurrent_N, _use_orthogonal, rnn_type):
        super(ObsEncoder, self).__init__()
        # Raw observation first goes through the structured input encoder.
        self.input_encoder = InputEncoder()
        # Project the encoder output to the RNN hidden size.
        self.input_embedding = get_fc(input_embedding_size, hidden_size)
        # Recurrent core over the embedded features.
        self.rnn = RNNLayer(hidden_size, hidden_size, _recurrent_N, _use_orthogonal, rnn_type=rnn_type)
        # Final MLP applied to the RNN output.
        self.after_rnn_mlp = get_fc(hidden_size, hidden_size)

    def forward(self, obs, rnn_states, masks):
        """Encode *obs*; returns (features, next_rnn_states)."""
        features = self.input_encoder(obs)
        features = self.input_embedding(features)
        rnn_out, rnn_states = self.rnn(features, rnn_states, masks)
        return self.after_rnn_mlp(rnn_out), rnn_states
369
+
370
+
371
class PolicyNetwork(nn.Module):
    """Actor network: obs encoder + active-player-id embedding + action head.

    The observation's first column carries the active player's id
    (0..10); the remainder is the feature vector consumed by
    ``ObsEncoder``.  An auxiliary head predicts the id from the features
    and the taken action (used as a training signal in
    ``eval_actions``).
    """

    def __init__(self, device=torch.device("cpu")):
        super(PolicyNetwork, self).__init__()
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.device = device
        self.hidden_size = 256
        self._use_policy_active_masks = True
        recurrent_N = 1
        use_orthogonal = True
        rnn_type = 'lstm'
        gain = 0.01
        action_space = gym.spaces.Discrete(20)
        self.action_dim = 19
        # Four 64-wide segment encoders plus the 9-wide match state.
        input_embedding_size = 64 * 4 + 9
        self.active_id_size = 1
        self.id_max = 11

        self.obs_encoder = ObsEncoder(input_embedding_size, self.hidden_size, recurrent_N, use_orthogonal, rnn_type)

        # Auxiliary head: predict the active player id from features + action.
        self.predict_id = get_fc(self.hidden_size + self.action_dim, self.id_max)
        # Embedding of the one-hot active-player id, concatenated with features.
        self.id_embedding = get_fc(self.id_max, self.id_max)

        self.before_act_wrapper = FcEncoder(2, self.hidden_size + self.id_max, self.hidden_size)
        self.act = ACTLayer(action_space, self.hidden_size, use_orthogonal, gain)

        self.to(device)

    def forward(self, obs, rnn_states, masks=None, available_actions=None, deterministic=False):
        """Compute actions and next RNN states for a batch of observations.

        ``masks`` defaults to a (1, 1) float32 array of ones (no reset).
        """
        # Bug fix: the original signature used a mutable numpy-array
        # default argument (evaluated once and shared across calls).
        # A None sentinel yields the same (1, 1) float32 array of ones.
        if masks is None:
            masks = np.ones((1, 1), dtype=np.float32)
        obs = check(obs).to(**self.tpdv)
        if available_actions is not None:
            available_actions = check(available_actions).to(**self.tpdv)
        masks = check(masks).to(**self.tpdv)
        rnn_states = check(rnn_states).to(**self.tpdv)

        # First column of obs is the active player's id; the rest is features.
        active_id = obs[:, :self.active_id_size].squeeze(1).long()
        id_onehot = torch.eye(self.id_max)[active_id].to(self.device)
        obs = obs[:, self.active_id_size:]

        obs_output, rnn_states = self.obs_encoder(obs, rnn_states, masks)
        id_output = self.id_embedding(id_onehot)
        output = torch.cat([id_output, obs_output], 1)

        output = self.before_act_wrapper(output)

        actions, action_log_probs = self.act(output, available_actions, deterministic)
        return actions, rnn_states

    def eval_actions(self, obs, rnn_states, action, masks, available_actions=None, active_masks=None):
        """Evaluate given actions; also run the id-prediction auxiliary head.

        Returns (action_log_probs, dist_entropy, values, id_prediction,
        id_groundtruth); ``values`` is always None here (no critic).
        """
        obs = check(obs).to(**self.tpdv)
        if available_actions is not None:
            available_actions = check(available_actions).to(**self.tpdv)
        if active_masks is not None:
            active_masks = check(active_masks).to(**self.tpdv)
        masks = check(masks).to(**self.tpdv)
        rnn_states = check(rnn_states).to(**self.tpdv)
        action = check(action).to(**self.tpdv)

        id_groundtruth = obs[:, :self.active_id_size].squeeze(1).long()
        id_onehot = torch.eye(self.id_max)[id_groundtruth].to(self.device)
        obs = obs[:, self.active_id_size:]

        obs_output, rnn_states = self.obs_encoder(obs, rnn_states, masks)
        id_output = self.id_embedding(id_onehot)

        action_onehot = torch.eye(self.action_dim)[action.squeeze(1).long()].to(self.device)

        # Auxiliary id prediction conditioned on features + taken action.
        id_prediction = self.predict_id(torch.cat([obs_output, action_onehot], 1))
        output = torch.cat([id_output, obs_output], 1)

        output = self.before_act_wrapper(output)
        action_log_probs, dist_entropy = self.act.evaluate_actions(
            output, action, available_actions,
            active_masks=active_masks if self._use_policy_active_masks else None)
        values = None
        return action_log_probs, dist_entropy, values, id_prediction, id_groundtruth
446
+
openrl_utils.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright 2023 The OpenRL Authors.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # https://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import numpy as np
18
+
19
# Pitch-area thresholds (normalized field coordinates).
THIRD_X = 0.3
BOX_X = 0.7
MAX_X = 1.0
BOX_Y = 0.24
MAX_Y = 0.42

# Action indices in the 19-dim discrete action space.
IDLE = 0
LEFT = 1
TOP_LEFT = 2
TOP = 3
TOP_RIGHT = 4
RIGHT = 5
BOTTOM_RIGHT = 6
BOTTOM = 7
BOTTOM_LEFT = 8
LONG_PASS = 9
HIGH_PASS = 10
SHORT_PASS = 11
SHOT = 12
SPRINT = 13
RELEASE_DIRECTION = 14
RELEASE_SPRINT = 15
SLIDING = 16
DRIBBLE = 17
RELEASE_DRIBBLE = 18
# Indices into the sticky-actions vector (direction slots 0-7).
STICKY_LEFT = 0
STICKY_TOP_LEFT = 1
STICKY_TOP = 2
STICKY_TOP_RIGHT = 3
STICKY_RIGHT = 4
STICKY_BOTTOM_RIGHT = 5
STICKY_BOTTOM = 6
STICKY_BOTTOM_LEFT = 7

# Direction-action groups and the matching unit movement vectors
# (ALL_DIRECTION_VECS[i] corresponds to ALL_DIRECTION_ACTIONS[i]).
RIGHT_ACTIONS = [TOP_RIGHT, RIGHT, BOTTOM_RIGHT, TOP, BOTTOM]
LEFT_ACTIONS = [TOP_LEFT, LEFT, BOTTOM_LEFT, TOP, BOTTOM]
BOTTOM_ACTIONS = [BOTTOM_LEFT, BOTTOM, BOTTOM_RIGHT, LEFT, RIGHT]
TOP_ACTIONS = [TOP_LEFT, TOP, TOP_RIGHT, LEFT, RIGHT]
ALL_DIRECTION_ACTIONS = [LEFT, TOP_LEFT, TOP, TOP_RIGHT, RIGHT, BOTTOM_RIGHT, BOTTOM, BOTTOM_LEFT]
ALL_DIRECTION_VECS = [(-1, 0), (-1, -1), (0, -1), (1, -1), (1, 0), (1, 1), (0, 1), (-1, 1)]
61
+
62
def get_direction_action(available_action, sticky_actions, forbidden_action, target_action, active_direction, need_sprint):
    """Build a 19-dim availability mask for a scripted directional move.

    NOTE(review): the incoming ``available_action`` and
    ``active_direction`` arguments are never read — the mask is rebuilt
    from scratch; both are kept in the signature for caller
    compatibility.
    """
    mask = np.zeros(19)
    mask[forbidden_action] = 0  # no-op on a fresh zero mask; kept for clarity
    mask[target_action] = 1

    # Sticky slot 8 appears to track the sprint state (inferred from usage).
    sprinting = sticky_actions[8]
    if need_sprint:
        mask[RELEASE_SPRINT] = 0
        if sprinting == 0:
            # Not sprinting yet: force a SPRINT action before anything else.
            mask = np.zeros(19)
            mask[SPRINT] = 1
    else:
        mask[SPRINT] = 0
        if sprinting == 1:
            # Currently sprinting but shouldn't be: force RELEASE_SPRINT.
            mask = np.zeros(19)
            mask[RELEASE_SPRINT] = 1
    return mask
78
+
79
+ def openrl_obs_deal(obs):
80
+
81
+ direction_x_bound = 0.03
82
+ direction_y_bound = 0.02
83
+ ball_direction_x_bound = 0.15
84
+ ball_direction_y_bound = 0.07
85
+ ball_direction_z_bound = 4
86
+ ball_rotation_x_bound = 0.0005
87
+ ball_rotation_y_bound = 0.0004
88
+ ball_rotation_z_bound = 0.015
89
+ active_id = [obs["active"]]
90
+ assert active_id[0] < 11 and active_id[0] >= 0, "active id is wrong, active id = {}".format(active_id[0])
91
+ # left team 88
92
+ left_position = np.concatenate(obs["left_team"])
93
+ left_direction = np.concatenate(obs["left_team_direction"])
94
+ left_tired_factor = obs["left_team_tired_factor"]
95
+ left_yellow_card = obs["left_team_yellow_card"]
96
+ left_red_card = ~obs["left_team_active"]
97
+ left_offside = np.zeros(11)
98
+ if obs["ball_owned_team"] == 0:
99
+ left_offside_line = max(0, obs["ball"][0], np.sort(obs["right_team"][:, 0])[-2])
100
+ left_offside = obs["left_team"][:, 0] > left_offside_line
101
+ left_offside[obs["ball_owned_player"]] = False
102
+
103
+ new_left_direction = left_direction.copy()
104
+ for counting in range(len(new_left_direction)):
105
+ new_left_direction[counting] = new_left_direction[counting] / direction_x_bound if counting % 2 == 0 else new_left_direction[counting] / direction_y_bound
106
+
107
+ left_team = np.concatenate([
108
+ left_position,
109
+ new_left_direction,
110
+ left_tired_factor,
111
+ left_yellow_card,
112
+ left_red_card,
113
+ left_offside,
114
+ ]).astype(np.float64)
115
+
116
+ # right team 88
117
+ right_position = np.concatenate(obs["right_team"])
118
+ right_direction = np.concatenate(obs["right_team_direction"])
119
+ right_tired_factor = obs["right_team_tired_factor"]
120
+ right_yellow_card = obs["right_team_yellow_card"]
121
+ right_red_card = ~obs["right_team_active"]
122
+ right_offside = np.zeros(11)
123
+ if obs["ball_owned_team"] == 1:
124
+ right_offside_line = min(0, obs["ball"][0], np.sort(obs["left_team"][:, 0])[1])
125
+ right_offside = obs["right_team"][:, 0] < right_offside_line
126
+ right_offside[obs["ball_owned_player"]] = False
127
+
128
+ new_right_direction = right_direction.copy()
129
+ for counting in range(len(new_right_direction)):
130
+ new_right_direction[counting] = new_right_direction[counting] / direction_x_bound if counting % 2 == 0 else new_right_direction[counting] / direction_y_bound
131
+
132
+ right_team = np.concatenate([
133
+ right_position,
134
+ new_right_direction,
135
+ right_tired_factor,
136
+ right_yellow_card,
137
+ right_red_card,
138
+ right_offside,
139
+ ]).astype(np.float64)
140
+
141
+ # active 18
142
+ sticky_actions = obs["sticky_actions"][:10]
143
+ active_position = obs["left_team"][obs["active"]]
144
+ active_direction = obs["left_team_direction"][obs["active"]]
145
+ active_tired_factor = left_tired_factor[obs["active"]]
146
+ active_yellow_card = left_yellow_card[obs["active"]]
147
+ active_red_card = left_red_card[obs["active"]]
148
+ active_offside = left_offside[obs["active"]]
149
+
150
+ new_active_direction = active_direction.copy()
151
+ new_active_direction[0] /= direction_x_bound
152
+ new_active_direction[1] /= direction_y_bound
153
+
154
+ active_player = np.concatenate([
155
+ sticky_actions,
156
+ active_position,
157
+ new_active_direction,
158
+ [active_tired_factor],
159
+ [active_yellow_card],
160
+ [active_red_card],
161
+ [active_offside],
162
+ ]).astype(np.float64)
163
+
164
+ # relative 69
165
+ relative_ball_position = obs["ball"][:2] - active_position
166
+ distance2ball = np.linalg.norm(relative_ball_position)
167
+ relative_left_position = obs["left_team"] - active_position
168
+ distance2left = np.linalg.norm(relative_left_position, axis=1)
169
+ relative_left_position = np.concatenate(relative_left_position)
170
+ relative_right_position = obs["right_team"] - active_position
171
+ distance2right = np.linalg.norm(relative_right_position, axis=1)
172
+ relative_right_position = np.concatenate(relative_right_position)
173
+ relative_info = np.concatenate([
174
+ relative_ball_position,
175
+ [distance2ball],
176
+ relative_left_position,
177
+ distance2left,
178
+ relative_right_position,
179
+ distance2right,
180
+ ]).astype(np.float64)
181
+
182
+ active_info = np.concatenate([active_player, relative_info]) # 87
183
+
184
+ # ball info 12
185
+ ball_owned_team = np.zeros(3)
186
+ ball_owned_team[obs["ball_owned_team"] + 1] = 1.0
187
+ new_ball_direction = obs["ball_direction"].copy()
188
+ new_ball_rotation = obs['ball_rotation'].copy()
189
+ for counting in range(len(new_ball_direction)):
190
+ if counting % 3 == 0:
191
+ new_ball_direction[counting] /= ball_direction_x_bound
192
+ new_ball_rotation[counting] /= ball_rotation_x_bound
193
+ if counting % 3 == 1:
194
+ new_ball_direction[counting] /= ball_direction_y_bound
195
+ new_ball_rotation[counting] /= ball_rotation_y_bound
196
+ if counting % 3 == 2:
197
+ new_ball_direction[counting] /= ball_direction_z_bound
198
+ new_ball_rotation[counting] /= ball_rotation_z_bound
199
+ ball_info = np.concatenate([
200
+ obs["ball"],
201
+ new_ball_direction,
202
+ ball_owned_team,
203
+ new_ball_rotation
204
+ ]).astype(np.float64)
205
+ # ball owned player 23
206
+ ball_owned_player = np.zeros(23)
207
+ if obs["ball_owned_team"] == 1: # 对手
208
+ ball_owned_player[11 + obs['ball_owned_player']] = 1.0
209
+ ball_owned_player_pos = obs['right_team'][obs['ball_owned_player']]
210
+ ball_owned_player_direction = obs["right_team_direction"][obs['ball_owned_player']]
211
+ ball_owner_tired_factor = right_tired_factor[obs['ball_owned_player']]
212
+ ball_owner_yellow_card = right_yellow_card[obs['ball_owned_player']]
213
+ ball_owner_red_card = right_red_card[obs['ball_owned_player']]
214
+ ball_owner_offside = right_offside[obs['ball_owned_player']]
215
+ elif obs["ball_owned_team"] == 0:
216
+ ball_owned_player[obs['ball_owned_player']] = 1.0
217
+ ball_owned_player_pos = obs['left_team'][obs['ball_owned_player']]
218
+ ball_owned_player_direction = obs["left_team_direction"][obs['ball_owned_player']]
219
+ ball_owner_tired_factor = left_tired_factor[obs['ball_owned_player']]
220
+ ball_owner_yellow_card = left_yellow_card[obs['ball_owned_player']]
221
+ ball_owner_red_card = left_red_card[obs['ball_owned_player']]
222
+ ball_owner_offside = left_offside[obs['ball_owned_player']]
223
+ else:
224
+ ball_owned_player[-1] = 1.0
225
+ ball_owned_player_pos = np.zeros(2)
226
+ ball_owned_player_direction = np.zeros(2)
227
+
228
+ relative_ball_owner_position = np.zeros(2)
229
+ distance2ballowner = np.zeros(1)
230
+ ball_owner_info = np.zeros(4)
231
+ if obs["ball_owned_team"] != -1:
232
+ relative_ball_owner_position = ball_owned_player_pos - active_position
233
+ distance2ballowner = [np.linalg.norm(relative_ball_owner_position)]
234
+ ball_owner_info = np.concatenate([
235
+ [ball_owner_tired_factor],
236
+ [ball_owner_yellow_card],
237
+ [ball_owner_red_card],
238
+ [ball_owner_offside]
239
+ ])
240
+
241
+ new_ball_owned_player_direction = ball_owned_player_direction.copy()
242
+ new_ball_owned_player_direction[0] /= direction_x_bound
243
+ new_ball_owned_player_direction[1] /= direction_y_bound
244
+
245
+ ball_own_active_info = np.concatenate([
246
+ ball_info, # 12
247
+ ball_owned_player, # 23
248
+ active_position, # 2
249
+ new_active_direction, # 2
250
+ [active_tired_factor], # 1
251
+ [active_yellow_card], # 1
252
+ [active_red_card], # 1
253
+ [active_offside], # 1
254
+ relative_ball_position, # 2
255
+ [distance2ball], # 1
256
+ ball_owned_player_pos, # 2
257
+ new_ball_owned_player_direction, # 2
258
+ relative_ball_owner_position, # 2
259
+ distance2ballowner, # 1
260
+ ball_owner_info # 4
261
+ ])
262
+
263
+ # match state
264
+ game_mode = np.zeros(7)
265
+ game_mode[obs["game_mode"]] = 1.0
266
+ goal_diff_ratio = (obs["score"][0] - obs["score"][1]) / 5
267
+ steps_left_ratio = obs["steps_left"] / 3001
268
+ match_state = np.concatenate([
269
+ game_mode,
270
+ [goal_diff_ratio],
271
+ [steps_left_ratio],
272
+ ]).astype(np.float64)
273
+
274
+ # available action
275
+ available_action = np.ones(19)
276
+ available_action[IDLE] = 0
277
+ available_action[RELEASE_DIRECTION] = 0
278
+ should_left = False
279
+
280
+
281
+ if obs["game_mode"] == 0:
282
+ active_x = active_position[0]
283
+ counting_right_enemy_num = 0
284
+ counting_right_teammate_num = 0
285
+ counting_left_teammate_num = 0
286
+ for enemy_pos in obs["right_team"][1:]:
287
+ if active_x < enemy_pos[0]:
288
+ counting_right_enemy_num += 1
289
+ for teammate_pos in obs["left_team"][1:]:
290
+ if active_x < teammate_pos[0]:
291
+ counting_right_teammate_num += 1
292
+ if active_x > teammate_pos[0]:
293
+ counting_left_teammate_num += 1
294
+
295
+ if active_x > obs['ball'][0] + 0.05:
296
+
297
+ if counting_left_teammate_num < 2:
298
+
299
+ if obs['ball_owned_team'] != 0:
300
+ should_left = True
301
+ if should_left:
302
+ available_action = get_direction_action(available_action, sticky_actions, RIGHT_ACTIONS, [LEFT, BOTTOM_LEFT, TOP_LEFT], active_direction, True)
303
+
304
+
305
+ if (abs(relative_ball_position[0]) > 0.75 or abs(relative_ball_position[1]) > 0.5):
306
+ all_directions_vecs = [np.array(v) / np.linalg.norm(np.array(v)) for v in ALL_DIRECTION_VECS]
307
+ best_direction = np.argmax([np.dot(relative_ball_position, v) for v in all_directions_vecs])
308
+ target_direction = ALL_DIRECTION_ACTIONS[best_direction]
309
+ forbidden_actions = ALL_DIRECTION_ACTIONS.copy()
310
+ forbidden_actions.remove(target_direction)
311
+ available_action = get_direction_action(available_action, sticky_actions, forbidden_actions, [target_direction], active_direction, True)
312
+
313
+
314
+ if_i_hold_ball = (obs["ball_owned_team"] == 0 and obs["ball_owned_player"] == obs['active'])
315
+ ball_pos_offset = 0.05
316
+ no_ball_pos_offset = 0.03
317
+
318
+ active_x, active_y = active_position[0], active_position[1]
319
+ if_outside = False
320
+ if active_x <= (-1 + no_ball_pos_offset) or (if_i_hold_ball and active_x <= (-1 + ball_pos_offset)):
321
+ if_outside = True
322
+ action_index = LEFT_ACTIONS
323
+ target_direction = RIGHT
324
+ elif active_x >= (1 - no_ball_pos_offset) or (if_i_hold_ball and active_x >= (1 - ball_pos_offset)):
325
+ if_outside = True
326
+ action_index = RIGHT_ACTIONS
327
+ target_direction = LEFT
328
+ elif active_y >= (0.42 - no_ball_pos_offset) or (if_i_hold_ball and active_y >= (0.42 - ball_pos_offset)):
329
+ if_outside = True
330
+ action_index = BOTTOM_ACTIONS
331
+ target_direction = TOP
332
+ elif active_y <= (-0.42 + no_ball_pos_offset) or (if_i_hold_ball and active_x <= (-0.42 + ball_pos_offset)):
333
+ if_outside = True
334
+ action_index = TOP_ACTIONS
335
+ target_direction = BOTTOM
336
+ if obs["game_mode"] in [1, 2, 3, 4, 5]:
337
+ left2ball = np.linalg.norm(obs["left_team"] - obs["ball"][:2], axis=1)
338
+ right2ball = np.linalg.norm(obs["right_team"] - obs["ball"][:2], axis=1)
339
+ if np.min(left2ball) < np.min(right2ball) and obs["active"] == np.argmin(left2ball):
340
+ if_outside = False
341
+ elif obs["game_mode"] in [6]:
342
+ if obs["ball"][0] > 0 and active_position[0] > BOX_X:
343
+ if_outside = False
344
+ if if_outside:
345
+ available_action = get_direction_action(available_action, sticky_actions, action_index, [target_direction], active_direction, False)
346
+
347
+ if np.sum(sticky_actions[:8]) == 0:
348
+ available_action[RELEASE_DIRECTION] = 0
349
+ if sticky_actions[8] == 0:
350
+ available_action[RELEASE_SPRINT] = 0
351
+ else:
352
+ available_action[SPRINT] = 0
353
+ if sticky_actions[9] == 0:
354
+ available_action[RELEASE_DRIBBLE] = 0
355
+ else:
356
+ available_action[DRIBBLE] = 0
357
+ if active_position[0] < 0.4 or abs(active_position[1]) > 0.3:
358
+ available_action[SHOT] = 0
359
+
360
+ if obs["game_mode"] == 0:
361
+ if obs["ball_owned_team"] == -1:
362
+ available_action[DRIBBLE] = 0
363
+ if distance2ball >= 0.05:
364
+ available_action[SLIDING] = 0
365
+ available_action[[LONG_PASS, HIGH_PASS, SHORT_PASS, SHOT]] = 0
366
+ elif obs["ball_owned_team"] == 0:
367
+ available_action[SLIDING] = 0
368
+ if distance2ball >= 0.05:
369
+ available_action[[LONG_PASS, HIGH_PASS, SHORT_PASS, SHOT, DRIBBLE]] = 0
370
+ elif obs["ball_owned_team"] == 1:
371
+ available_action[DRIBBLE] = 0
372
+ if distance2ball >= 0.05:
373
+ available_action[[LONG_PASS, HIGH_PASS, SHORT_PASS, SHOT, SLIDING]] = 0
374
+ elif obs["game_mode"] in [1, 2, 3, 4, 5]:
375
+ left2ball = np.linalg.norm(obs["left_team"] - obs["ball"][:2], axis=1)
376
+ right2ball = np.linalg.norm(obs["right_team"] - obs["ball"][:2], axis=1)
377
+ if np.min(left2ball) < np.min(right2ball) and obs["active"] == np.argmin(left2ball):
378
+ available_action[[SPRINT, RELEASE_SPRINT, SLIDING, DRIBBLE, RELEASE_DRIBBLE]] = 0
379
+ else:
380
+ available_action[[LONG_PASS, HIGH_PASS, SHORT_PASS, SHOT]] = 0
381
+ available_action[[SLIDING, DRIBBLE, RELEASE_DRIBBLE]] = 0
382
+ elif obs["game_mode"] == 6:
383
+ if obs["ball"][0] > 0 and active_position[0] > BOX_X:
384
+ available_action[[LONG_PASS, HIGH_PASS, SHORT_PASS]] = 0
385
+ available_action[[SPRINT, RELEASE_SPRINT, SLIDING, DRIBBLE, RELEASE_DRIBBLE]] = 0
386
+ else:
387
+ available_action[[LONG_PASS, HIGH_PASS, SHORT_PASS, SHOT]] = 0
388
+ available_action[[SLIDING, DRIBBLE, RELEASE_DRIBBLE]] = 0
389
+
390
+
391
+ obs = np.concatenate([
392
+ active_id, # 1
393
+ active_info, # 87
394
+ ball_own_active_info, # 57
395
+ left_team, # 88
396
+ right_team, # 88
397
+ match_state, # 9
398
+ ])
399
+
400
+ share_obs = np.concatenate([
401
+ ball_info, # 12
402
+ ball_owned_player, # 23
403
+ left_team, # 88
404
+ right_team, # 88
405
+ match_state, # 9
406
+ ])
407
+
408
+ assert available_action.sum() > 0
409
+ return dict(
410
+ obs=obs,
411
+ share_obs=share_obs,
412
+ available_action=available_action,
413
+ )
414
+
415
+
416
+ def _t2n(x):
417
+ return x.detach().cpu().numpy()
418
+
419
+
420
+
421
+
submission.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright 2023 The OpenRL Authors.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # https://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ """"""
19
+ import os
20
+ import sys
21
+ from pathlib import Path
22
+ import numpy as np
23
+ import torch
24
+
25
+ base_dir = Path(__file__).resolve().parent
26
+ sys.path.append(str(base_dir))
27
+
28
+ from openrl_policy import PolicyNetwork
29
+ from openrl_utils import openrl_obs_deal, _t2n
30
+ from goal_keeper import agent_get_action
31
+
32
class OpenRLAgent():
    """Wraps the trained policy network plus a rule-based goal keeper.

    Player slot 0 (the keeper) is handled by the scripted
    ``agent_get_action``; every other slot queries the neural policy,
    keeping one recurrent hidden state per player slot.
    """

    def __init__(self):
        # One RNN hidden state of shape (1, 1, 1, 512) per controllable slot.
        self.rnn_hidden_state = [
            np.zeros([1, 1, 1, 512], dtype=np.float32) for _ in range(11)
        ]
        # Weights are expected next to this file; loaded onto CPU.
        weights_path = os.path.dirname(os.path.abspath(__file__)) + '/actor.pt'
        self.model = PolicyNetwork()
        self.model.load_state_dict(
            torch.load(weights_path, map_location=torch.device("cpu"))
        )
        self.model.eval()

    def get_action(self, raw_obs, idx):
        """Return a one-hot action as a nested list of shape [1][19] for slot *idx*."""
        # Slot 0 (the goal keeper) is fully rule-based.
        if idx == 0:
            keeper_action = [[0] * 19]
            keeper_action[0][agent_get_action(raw_obs)[0]] = 1
            return keeper_action

        processed = openrl_obs_deal(raw_obs)

        # Flatten the 330-dim observation and the stored hidden state into
        # the batch layout the policy network expects.
        obs = np.concatenate(processed['obs'].reshape(1, 1, 330))
        rnn_state = np.concatenate(self.rnn_hidden_state[idx])
        avail = np.zeros(20)
        avail[:19] = processed['available_action']
        avail = np.concatenate(avail.reshape([1, 1, 20]))

        with torch.no_grad():
            actions, rnn_state = self.model(
                obs, rnn_state, available_actions=avail, deterministic=True
            )

        # NOTE(review): presumably remaps action 17 to 15 when the sprint
        # sticky (index 8) is already active — kept exactly as original.
        if actions[0][0] == 17 and raw_obs["sticky_actions"][8] == 1:
            actions[0][0] = 15
        self.rnn_hidden_state[idx] = np.array(np.split(_t2n(rnn_state), 1))

        one_hot = [[0] * 19]
        one_hot[0][actions[0]] = 1
        return one_hot
65
+
66
+ agent = OpenRLAgent()
67
+
68
def my_controller(obs_list, action_space_list, is_act_continuous=False):
    """Framework entry point: route the observation to the global ``agent``.

    Removes ``controlled_player_index`` from *obs_list* (mutating the
    caller's dict, as the original did) and maps it onto one of the 11
    player slots.
    """
    player_slot = obs_list.pop('controlled_player_index') % 11
    return agent.get_action(obs_list, player_slot)
73
+
74
def jidi_controller(obs_list=None):
    """Alias entry point (renamed to avoid a loader name clash on the Jidi platform).

    Returns ``None`` when called with no observation; otherwise delegates
    to :func:`my_controller` and sanity-checks the nested-list result.
    """
    if obs_list is None:
        return
    result = my_controller(obs_list, None)
    assert isinstance(result, list)
    assert isinstance(result[0], list)
    return result