ClementBM committed
Commit ffe7549
Parent: 12d2509

first commit

Files changed (47)
  1. .vscode/extensions.json +7 -0
  2. .vscode/settings.json +28 -0
  3. README.md +1 -1
  4. connectfour/__init__.py +0 -0
  5. connectfour/__pycache__/__init__.cpython-38.pyc +0 -0
  6. connectfour/__pycache__/app.cpython-38.pyc +0 -0
  7. connectfour/app.py +250 -0
  8. connectfour/checkpoint/.Rhistory +0 -0
  9. connectfour/checkpoint/.is_checkpoint +0 -0
  10. connectfour/checkpoint/.tune_metadata +0 -0
  11. connectfour/checkpoint/__init__.py +3 -0
  12. connectfour/checkpoint/algorithm_state.pkl +3 -0
  13. connectfour/checkpoint/policies/always_same/policy_state.pkl +3 -0
  14. connectfour/checkpoint/policies/always_same/rllib_checkpoint.json +1 -0
  15. connectfour/checkpoint/policies/beat_last/policy_state.pkl +3 -0
  16. connectfour/checkpoint/policies/beat_last/rllib_checkpoint.json +1 -0
  17. connectfour/checkpoint/policies/learned/policy_state.pkl +3 -0
  18. connectfour/checkpoint/policies/learned/rllib_checkpoint.json +1 -0
  19. connectfour/checkpoint/policies/learned_v1/policy_state.pkl +3 -0
  20. connectfour/checkpoint/policies/learned_v1/rllib_checkpoint.json +1 -0
  21. connectfour/checkpoint/policies/learned_v2/policy_state.pkl +3 -0
  22. connectfour/checkpoint/policies/learned_v2/rllib_checkpoint.json +1 -0
  23. connectfour/checkpoint/policies/learned_v3/policy_state.pkl +3 -0
  24. connectfour/checkpoint/policies/learned_v3/rllib_checkpoint.json +1 -0
  25. connectfour/checkpoint/policies/learned_v4/policy_state.pkl +3 -0
  26. connectfour/checkpoint/policies/learned_v4/rllib_checkpoint.json +1 -0
  27. connectfour/checkpoint/policies/learned_v5/policy_state.pkl +3 -0
  28. connectfour/checkpoint/policies/learned_v5/rllib_checkpoint.json +1 -0
  29. connectfour/checkpoint/policies/linear/policy_state.pkl +3 -0
  30. connectfour/checkpoint/policies/linear/rllib_checkpoint.json +1 -0
  31. connectfour/checkpoint/policies/random/policy_state.pkl +3 -0
  32. connectfour/checkpoint/policies/random/rllib_checkpoint.json +1 -0
  33. connectfour/checkpoint/rllib_checkpoint.json +1 -0
  34. connectfour/training/__init__.py +0 -0
  35. connectfour/training/__pycache__/__init__.cpython-38.pyc +0 -0
  36. connectfour/training/__pycache__/callbacks.cpython-38.pyc +0 -0
  37. connectfour/training/__pycache__/dummy_policies.cpython-38.pyc +0 -0
  38. connectfour/training/__pycache__/models.cpython-38.pyc +0 -0
  39. connectfour/training/__pycache__/wrappers.cpython-38.pyc +0 -0
  40. connectfour/training/callbacks.py +88 -0
  41. connectfour/training/dummy_policies.py +130 -0
  42. connectfour/training/models.py +119 -0
  43. connectfour/training/train.py +140 -0
  44. connectfour/training/wrappers.py +112 -0
  45. poetry.lock +0 -0
  46. pyproject.toml +37 -0
  47. requirements.txt +141 -0
.vscode/extensions.json ADDED
@@ -0,0 +1,7 @@
+{
+    // See https://go.microsoft.com/fwlink/?LinkId=827846
+    // for the documentation about the extensions.json format
+    "recommendations": [
+        "ms-python.python"
+    ]
+}
.vscode/settings.json ADDED
@@ -0,0 +1,28 @@
+{
+    "python.linting.enabled": true,
+    "python.linting.mypyEnabled": true,
+    "python.pythonPath": "${env:PYTHON_VENV_LOC}",
+    "python.testing.unittestEnabled": false,
+    "python.testing.nosetestsEnabled": false,
+    "python.testing.pytestEnabled": true,
+    "python.testing.pytestArgs": [
+        "tests"
+    ],
+    "editor.tabSize": 4,
+    "[python]": {
+        "editor.formatOnSave": true
+    },
+    "python.formatting.provider": "black",
+    "files.exclude": {
+        ".mypy_cache": true,
+        ".pytest_cache": true,
+        ".venv": true,
+        "**/__pycache__": true
+    },
+    "files.watcherExclude": {
+        ".venv/**": true,
+        "**/__pycache__/**": true,
+        ".mypy_cache/**": true,
+        ".pytest_cache/**": true
+    }
+}
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: pink
 colorTo: blue
 sdk: gradio
 sdk_version: 3.23.0
-app_file: app.py
+app_file: connectfour/app.py
 pinned: false
 ---
connectfour/__init__.py ADDED
File without changes
connectfour/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (173 Bytes).
connectfour/__pycache__/app.cpython-38.pyc ADDED
Binary file (6.51 kB).
connectfour/app.py ADDED
@@ -0,0 +1,250 @@
+import gradio as gr
+from ray.serve.gradio_integrations import GradioServer, GradioIngress
+
+import gradio as gr
+from pettingzoo.classic import connect_four_v3
+import ray.rllib.algorithms.ppo as ppo
+import numpy as np
+import time
+from ray.tune import register_env
+from connectfour.training.models import Connect4MaskModel
+from connectfour.checkpoint import CHECKPOINT
+
+from connectfour.training.wrappers import Connect4Env
+
+demo = gr.Blocks()
+
+POLICY_ID = "learned_v5"
+
+
+class Connect4:
+    def __init__(self, who_plays_first) -> None:
+        self.init_env(who_plays_first)
+
+    def init_env(self, who_plays_first):
+        # define how to make the environment
+        env_creator = lambda config: connect_four_v3.env(render_mode="rgb_array")
+
+        # register the environment under an rllib name
+        register_env("connect4", lambda config: Connect4Env(env_creator(config)))
+
+        orig_env = connect_four_v3.env(render_mode="rgb_array")
+        self.env = Connect4Env(orig_env)
+
+        self.done = False
+        self.obs, info = self.env.reset()
+
+        if who_plays_first == "You":
+            self.human = self.player_id
+        else:
+            self.play()
+            self.human = self.player_id
+
+        return self.render_and_state
+
+    def get_algo(self, checkpoint):
+        config = (
+            ppo.PPOConfig()
+            .environment("connect4")
+            .framework("torch")
+            .training(model={"custom_model": Connect4MaskModel})
+        )
+        config.explore = False
+        self.algo = config.build()
+        self.algo.restore(checkpoint)
+
+    def play(self, action=None):
+        if self.human != self.player_id:
+            action = self.algo.compute_single_action(
+                self.obs[self.player_id], policy_id=POLICY_ID
+            )
+
+        if action not in self.legal_moves:
+            action = np.random.choice(self.legal_moves)
+
+        player_actions = {self.player_id: action}
+
+        self.obs, self.reward, terminated, truncated, info = self.env.step(
+            player_actions
+        )
+        self.done = terminated["__all__"] or truncated["__all__"]
+        return self.render_and_state
+
+    @property
+    def render_and_state(self):
+        end_message = "End of the game"
+        if self.done:
+            if self.reward[self.human] > 0:
+                end_message += ": You WIN !!"
+            elif self.reward[self.human] < 0:
+                end_message += ": You LOSE !!"
+            return self.env.render(), end_message
+
+        return self.env.render(), "Game On"
+
+    @property
+    def player_id(self):
+        return list(self.obs.keys())[0]
+
+    @property
+    def legal_moves(self):
+        return np.arange(7)[self.obs[self.player_id]["action_mask"] == 1]
+
+
+with demo:
+    connect4 = Connect4("You")
+    connect4.get_algo(str(CHECKPOINT))
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            gr.Markdown("# Let's Play Connect Four !")
+
+            who_plays_first = gr.Radio(
+                label="Who plays first", choices=["You", "Bot"], value="You"
+            )
+            reinitialize = gr.Button("New Game")
+
+            game_state = gr.Text(value="Game On", interactive=False, label="Status")
+
+        with gr.Column(scale=1):
+            output = gr.Image(
+                label="Connect Four Grid",
+                type="numpy",
+                show_label=False,
+                value=connect4.env.render(),
+            )
+
+    with gr.Row():
+        with gr.Column(scale=1, min_width=20):
+            drop_token0_btn = gr.Button("X")
+        with gr.Column(scale=1, min_width=20):
+            drop_token1_btn = gr.Button("X")
+        with gr.Column(scale=1, min_width=20):
+            drop_token2_btn = gr.Button("X")
+        with gr.Column(scale=1, min_width=20):
+            drop_token3_btn = gr.Button("X")
+        with gr.Column(scale=1, min_width=20):
+            drop_token4_btn = gr.Button("X")
+        with gr.Column(scale=1, min_width=20):
+            drop_token5_btn = gr.Button("X")
+        with gr.Column(scale=1, min_width=20):
+            drop_token6_btn = gr.Button("X")
+
+    who_plays_first.change(
+        connect4.init_env, who_plays_first, outputs=[output, game_state]
+    )
+
+    def reinit_game(who_plays_first):
+        output, game_state = connect4.init_env(who_plays_first)
+        return output, game_state, gr.Checkbox.update(interactive=True)
+
+    reinitialize.click(
+        reinit_game, who_plays_first, outputs=[output, game_state, who_plays_first]
+    )
+
+    def wait(game_state_value):
+        if game_state_value == "Game On":
+            time.sleep(1)
+            return gr.Checkbox.update(interactive=False)
+        else:
+            return gr.Checkbox.update(interactive=True)
+
+    def bot(game_state_value):
+        if game_state_value == "Game On":
+            rendered_env = connect4.play()
+            return *rendered_env, gr.Checkbox.update(
+                interactive=False
+            ) if rendered_env[1] == "Game On" else gr.Checkbox.update(interactive=True)
+        return (
+            gr.Image.update(),
+            game_state_value,
+            gr.Checkbox.update(interactive=True),
+        )
+
+    drop_token0_btn.click(
+        lambda: connect4.play(0),
+        outputs=[output, game_state],
+    ).then(
+        wait, inputs=[game_state], outputs=who_plays_first
+    ).then(bot, inputs=[game_state], outputs=[output, game_state, who_plays_first])
+
+    drop_token1_btn.click(
+        lambda: connect4.play(1),
+        outputs=[output, game_state],
+    ).then(
+        wait, inputs=[game_state], outputs=who_plays_first
+    ).then(bot, inputs=[game_state], outputs=[output, game_state, who_plays_first])
+
+    drop_token2_btn.click(
+        lambda: connect4.play(2),
+        outputs=[output, game_state],
+    ).then(
+        wait, inputs=[game_state], outputs=who_plays_first
+    ).then(bot, inputs=[game_state], outputs=[output, game_state, who_plays_first])
+
+    drop_token3_btn.click(
+        lambda: connect4.play(3),
+        outputs=[output, game_state],
+    ).then(
+        wait, inputs=[game_state], outputs=who_plays_first
+    ).then(bot, inputs=[game_state], outputs=[output, game_state, who_plays_first])
+
+    drop_token4_btn.click(
+        lambda: connect4.play(4),
+        outputs=[output, game_state],
+    ).then(
+        wait, inputs=[game_state], outputs=who_plays_first
+    ).then(bot, inputs=[game_state], outputs=[output, game_state, who_plays_first])
+
+    drop_token5_btn.click(
+        lambda: connect4.play(5),
+        outputs=[output, game_state],
+    ).then(
+        wait, inputs=[game_state], outputs=who_plays_first
+    ).then(bot, inputs=[game_state], outputs=[output, game_state, who_plays_first])
+
+    drop_token6_btn.click(
+        lambda: connect4.play(6),
+        outputs=[output, game_state],
+    ).then(
+        wait, inputs=[game_state], outputs=who_plays_first
+    ).then(bot, inputs=[game_state], outputs=[output, game_state, who_plays_first])
+
+    def game_state_change(value):
+        if value == "Game On":
+            return [
+                gr.Button.update(interactive=True),
+                gr.Button.update(interactive=True),
+                gr.Button.update(interactive=True),
+                gr.Button.update(interactive=True),
+                gr.Button.update(interactive=True),
+                gr.Button.update(interactive=True),
+                gr.Button.update(interactive=True),
+            ]
+        else:
+            return [
+                gr.Button.update(interactive=False),
+                gr.Button.update(interactive=False),
+                gr.Button.update(interactive=False),
+                gr.Button.update(interactive=False),
+                gr.Button.update(interactive=False),
+                gr.Button.update(interactive=False),
+                gr.Button.update(interactive=False),
+            ]
+
+    game_state.change(
+        game_state_change,
+        game_state,
+        outputs=[
+            drop_token0_btn,
+            drop_token1_btn,
+            drop_token2_btn,
+            drop_token3_btn,
+            drop_token4_btn,
+            drop_token5_btn,
+            drop_token6_btn,
+        ],
+    )
+
+
+demo.launch()
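
Taken together, app.py builds the wrapped environment, restores the PPO algorithm from the bundled checkpoint, and queries the "learned_v5" policy one move at a time. A condensed, Gradio-free sketch of that loop, using only calls that appear in the file above (the column played and variable names are illustrative):

import ray.rllib.algorithms.ppo as ppo
from pettingzoo.classic import connect_four_v3
from ray.tune import register_env

from connectfour.checkpoint import CHECKPOINT
from connectfour.training.models import Connect4MaskModel
from connectfour.training.wrappers import Connect4Env

# register the env name that PPOConfig.environment("connect4") refers to
register_env(
    "connect4",
    lambda cfg: Connect4Env(connect_four_v3.env(render_mode="rgb_array")),
)

config = (
    ppo.PPOConfig()
    .environment("connect4")
    .framework("torch")
    .training(model={"custom_model": Connect4MaskModel})
)
config.explore = False
algo = config.build()
algo.restore(str(CHECKPOINT))

env = Connect4Env(connect_four_v3.env(render_mode="rgb_array"))
obs, _ = env.reset()
agent = list(obs.keys())[0]
# let the restored policy pick a column for the current player
action = algo.compute_single_action(obs[agent], policy_id="learned_v5")
obs, reward, terminated, truncated, _ = env.step({agent: action})
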
connectfour/checkpoint/.Rhistory ADDED
File without changes
connectfour/checkpoint/.is_checkpoint ADDED
File without changes
connectfour/checkpoint/.tune_metadata ADDED
Binary file (15.6 kB).
connectfour/checkpoint/__init__.py ADDED
@@ -0,0 +1,3 @@
+from pathlib import Path
+
+CHECKPOINT = Path(__file__).parent.absolute()
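
This makes the checkpoint directory importable as a constant, so nothing has to hard-code a path. A minimal usage sketch, mirroring how connectfour/app.py consumes it:

from connectfour.checkpoint import CHECKPOINT

# CHECKPOINT is the connectfour/checkpoint/ directory itself, i.e. the folder
# holding algorithm_state.pkl, rllib_checkpoint.json and policies/.
print(CHECKPOINT)
# algo.restore(str(CHECKPOINT))  # as done in connectfour/app.py
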
connectfour/checkpoint/algorithm_state.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cbbc198c3406897931f5f18046a88181b8abff1aedbea1d869329731c9a50853
+size 66321
connectfour/checkpoint/policies/always_same/policy_state.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d278413093ad1bc4f227279e3dab7be04ebd70ca1ed156a1363515c69d0a858e
+size 10992
connectfour/checkpoint/policies/always_same/rllib_checkpoint.json ADDED
@@ -0,0 +1 @@
+{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
connectfour/checkpoint/policies/beat_last/policy_state.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cd422258c16de0866599730a5a5b2b48e2ee81cbae69f9d5471deeae76c42b47
+size 10992
connectfour/checkpoint/policies/beat_last/rllib_checkpoint.json ADDED
@@ -0,0 +1 @@
+{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
connectfour/checkpoint/policies/learned/policy_state.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2a517583e5fcad7e483bca619723583cc6928499390c1fcfc25d907e109cd4b4
+size 2139442
connectfour/checkpoint/policies/learned/rllib_checkpoint.json ADDED
@@ -0,0 +1 @@
+{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
connectfour/checkpoint/policies/learned_v1/policy_state.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:276c26007c2419a688c27f9dfa70c20fecb468a0aa07d28d6a9e8099bbc849be
+size 2139439
connectfour/checkpoint/policies/learned_v1/rllib_checkpoint.json ADDED
@@ -0,0 +1 @@
+{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
connectfour/checkpoint/policies/learned_v2/policy_state.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e37a485d3a54f7a8b194693e7a61f790e67071358130178fa01cdbd840c4a4da
+size 2139439
connectfour/checkpoint/policies/learned_v2/rllib_checkpoint.json ADDED
@@ -0,0 +1 @@
+{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
connectfour/checkpoint/policies/learned_v3/policy_state.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9f90899ae98a387e312333b234041c68b9c50da4af92ee5250686087a39eebb3
+size 2139439
connectfour/checkpoint/policies/learned_v3/rllib_checkpoint.json ADDED
@@ -0,0 +1 @@
+{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
connectfour/checkpoint/policies/learned_v4/policy_state.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3af3b3fe41bac489cb693af387b1ccc4437a532a78d539b3abb4cc5f77929592
+size 2139439
connectfour/checkpoint/policies/learned_v4/rllib_checkpoint.json ADDED
@@ -0,0 +1 @@
+{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
connectfour/checkpoint/policies/learned_v5/policy_state.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f2b28b979e2f4411d196e03ca75ea7f25f7601bb997aa8bcdcf1d49c9ea30754
+size 2139439
connectfour/checkpoint/policies/learned_v5/rllib_checkpoint.json ADDED
@@ -0,0 +1 @@
+{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
connectfour/checkpoint/policies/linear/policy_state.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4f70d44ac661632dc0557204abe34308dfb25b800a668b49c2efd9a2a73a7bc0
+size 10992
connectfour/checkpoint/policies/linear/rllib_checkpoint.json ADDED
@@ -0,0 +1 @@
+{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
connectfour/checkpoint/policies/random/policy_state.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f3b1ab86bada035779feedb2b92ae0a64f6d9474bb4f0ae44324e17d65659764
+size 10992
connectfour/checkpoint/policies/random/rllib_checkpoint.json ADDED
@@ -0,0 +1 @@
+{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
connectfour/checkpoint/rllib_checkpoint.json ADDED
@@ -0,0 +1 @@
+{"type": "Algorithm", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
connectfour/training/__init__.py ADDED
File without changes
connectfour/training/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (182 Bytes).
connectfour/training/__pycache__/callbacks.cpython-38.pyc ADDED
Binary file (2.67 kB).
connectfour/training/__pycache__/dummy_policies.cpython-38.pyc ADDED
Binary file (5.36 kB).
connectfour/training/__pycache__/models.cpython-38.pyc ADDED
Binary file (2.91 kB).
connectfour/training/__pycache__/wrappers.cpython-38.pyc ADDED
Binary file (3.72 kB).
connectfour/training/callbacks.py ADDED
@@ -0,0 +1,88 @@
+from ray.rllib.algorithms.callbacks import DefaultCallbacks
+import numpy as np
+
+
+def create_self_play_callback(win_rate_thr, opponent_policies):
+    class SelfPlayCallback(DefaultCallbacks):
+        win_rate_threshold = win_rate_thr
+
+        def __init__(self):
+            super().__init__()
+            self.current_opponent = 0
+
+        def on_train_result(self, *, algorithm, result, **kwargs):
+            # Get the win rate for the train batch.
+            # Note that normally, one should set up a proper evaluation config,
+            # such that evaluation always happens on the already updated policy,
+            # instead of on the already used train_batch.
+            main_rew = result["hist_stats"].pop("policy_learned_reward")
+            opponent_rew = result["hist_stats"].pop("episode_reward")
+
+            if len(main_rew) != len(opponent_rew):
+                raise Exception(
+                    "len(main_rew) != len(opponent_rew)",
+                    len(main_rew),
+                    len(opponent_rew),
+                    result["hist_stats"].keys(),
+                    "episode len",
+                    len(opponent_rew),
+                )
+
+            won = 0
+            for r_main, r_opponent in zip(main_rew, opponent_rew):
+                if r_main > r_opponent:
+                    won += 1
+            win_rate = won / len(main_rew)
+
+            result["win_rate"] = win_rate
+            print(f"Iter={algorithm.iteration} win-rate={win_rate} -> ", end="")
+
+            # If the win rate is good -> Snapshot the current policy and play against
+            # it next, keeping the snapshot fixed and only improving the "learned"
+            # policy.
+            if win_rate > self.win_rate_threshold:
+                self.current_opponent += 1
+                new_pol_id = f"learned_v{self.current_opponent}"
+                print(
+                    f"Iter={algorithm.iteration} ### Adding new opponent to the mix ({new_pol_id})."
+                )
+
+                # Re-define the mapping function, such that "learned" is forced
+                # to play against any of the previously played policies
+                # (excluding "random").
+                def policy_mapping_fn(agent_id, episode, worker, **kwargs):
+                    # agent_id ends in [0|1] -> policy depends on episode ID.
+                    # This way, we make sure that both policies sometimes play
+                    # agent0 (start player) and sometimes agent1 (player to move 2nd).
+                    return (
+                        "learned"
+                        if episode.episode_id % 2 == int(agent_id[-1:])
+                        else np.random.choice(
+                            opponent_policies
+                            + [
+                                f"learned_v{i}"
+                                for i in range(1, self.current_opponent + 1)
+                            ]
+                        )
+                    )
+
+                new_policy = algorithm.add_policy(
+                    policy_id=new_pol_id,
+                    policy_cls=type(algorithm.get_policy("learned")),
+                    policy_mapping_fn=policy_mapping_fn,
+                )
+
+                # Set the weights of the new policy to the learned policy.
+                # We'll keep training the learned policy, whereas `new_pol_id` will
+                # remain fixed.
+                learned_state = algorithm.get_policy("learned").get_state()
+                new_policy.set_state(learned_state)
+                # We need to sync the just copied local weights (from learned policy)
+                # to all the remote workers as well.
+                algorithm.workers.sync_weights()
+            else:
+                print("not good enough; will keep learning ...")
+
+            result["league_size"] = self.current_opponent + len(opponent_policies) + 1
+
+    return SelfPlayCallback
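
Because create_self_play_callback returns the callback class itself (not an instance), it can be handed straight to RLlib's .callbacks() setter. A short sketch of the wiring, matching how train.py below uses it:

import ray.rllib.algorithms.ppo as ppo
from connectfour.training.callbacks import create_self_play_callback

# Each time the learned policy clears the win-rate bar, the callback freezes
# a snapshot ("learned_v1", "learned_v2", ...) and adds it to the opponent pool.
config = ppo.PPOConfig().callbacks(
    create_self_play_callback(
        win_rate_thr=0.95,
        opponent_policies=["always_same", "beat_last", "random", "linear"],
    )
)
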
connectfour/training/dummy_policies.py ADDED
@@ -0,0 +1,130 @@
+import numpy as np
+import random
+from ray.rllib.policy.policy import Policy
+from ray.rllib.utils.annotations import override
+from ray.rllib.models.modelv2 import restore_original_dimensions
+
+
+class HeuristicBase(Policy):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.exploration = self._create_exploration()
+
+    def learn_on_batch(self, samples):
+        pass
+
+    @override(Policy)
+    def get_weights(self):
+        """No weights to save."""
+        return {}
+
+    @override(Policy)
+    def set_weights(self, weights):
+        """No weights to set."""
+        pass
+
+    @override(Policy)
+    def compute_actions(
+        self,
+        obs_batch,
+        state_batches=None,
+        prev_action_batch=None,
+        prev_reward_batch=None,
+        info_batch=None,
+        episodes=None,
+        **kwargs
+    ):
+        obs_batch = restore_original_dimensions(
+            np.array(obs_batch, dtype=np.float32), self.observation_space, tensorlib=np
+        )
+        return self._do_compute_actions(obs_batch)
+
+    def pick_legal_action(self, legal_action):
+        legal_choices = np.arange(len(legal_action))[legal_action == 1]
+        return np.random.choice(legal_choices)
+
+
+class AlwaysSameHeuristic(HeuristicBase):
+    """
+    Pick a random column and stick with it for the entire episode.
+    """
+
+    _rand_choice = random.choice(range(7))
+
+    def _do_compute_actions(self, obs_batch):
+        def select_action(legal_action):
+            legal_choices = np.arange(len(legal_action))[legal_action == 1]
+
+            if self._rand_choice not in legal_choices:
+                self._rand_choice = np.random.choice(legal_choices)
+
+            return self._rand_choice
+
+        return [select_action(x) for x in obs_batch["action_mask"]], [], {}
+
+
+class LinearHeuristic(HeuristicBase):
+    """
+    Pick a random starting column, then shift the column index by one
+    (in a random fixed direction) on every move.
+    """
+
+    _rand_choice = random.choice(range(7))
+    _rand_sign = np.random.choice([-1, 1])
+
+    def _do_compute_actions(self, obs_batch):
+        def select_action(legal_action):
+            legal_choices = np.arange(len(legal_action))[legal_action == 1]
+
+            self._rand_choice += 1 * self._rand_sign
+
+            if self._rand_choice not in legal_choices:
+                self._rand_choice = np.random.choice(legal_choices)
+
+            return self._rand_choice
+
+        return [select_action(x) for x in obs_batch["action_mask"]], [], {}
+
+
+class BeatLastHeuristic(HeuristicBase):
+    """
+    Answer the opponent: play in a column where the opponent holds more
+    tokens, falling back to a random legal move.
+    """
+
+    def _do_compute_actions(self, obs_batch):
+        def select_action(legal_action, observation):
+            legal_choices = np.arange(len(legal_action))[legal_action == 1]
+
+            obs_sums = np.sum(observation, axis=0)
+
+            desired_actions = np.squeeze(np.argwhere(obs_sums[:, 0] < obs_sums[:, 1]))
+            if desired_actions.size == 0:
+                return np.random.choice(legal_choices)
+
+            if desired_actions.size == 1:
+                desired_action = desired_actions[()]
+            else:
+                desired_action = np.random.choice(desired_actions)
+            if desired_action in legal_choices:
+                return desired_action
+
+            return np.random.choice(legal_choices)
+
+        return (
+            [
+                select_action(x, y)
+                for x, y in zip(obs_batch["action_mask"], obs_batch["observation"])
+            ],
+            [],
+            {},
+        )
+
+
+class RandomHeuristic(HeuristicBase):
+    """
+    Just pick a random legal action.
+    The environment's observation needs to be a dictionary with an
+    'action_mask' key containing the legal actions for the agent.
+    """
+
+    def _do_compute_actions(self, obs_batch):
+        return [self.pick_legal_action(x) for x in obs_batch["action_mask"]], [], {}
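
None of these heuristics learn anything; they exist to give the PPO policy a pool of fixed opponents. They are plugged into the multi-agent config as PolicySpec entries, exactly as train.py does below:

from ray.rllib.policy.policy import PolicySpec
from connectfour.training.dummy_policies import (
    AlwaysSameHeuristic,
    BeatLastHeuristic,
    LinearHeuristic,
    RandomHeuristic,
)

policies = {
    "learned": PolicySpec(),  # the one policy PPO actually trains
    "always_same": PolicySpec(policy_class=AlwaysSameHeuristic),
    "beat_last": PolicySpec(policy_class=BeatLastHeuristic),
    "linear": PolicySpec(policy_class=LinearHeuristic),
    "random": PolicySpec(policy_class=RandomHeuristic),
}
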
connectfour/training/models.py ADDED
@@ -0,0 +1,119 @@
+from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
+from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
+from gymnasium.spaces import Dict
+from ray.rllib.utils.torch_utils import FLOAT_MIN
+from ray.rllib.utils.framework import try_import_torch
+from ray.rllib.algorithms.sac.sac_torch_model import SACTorchModel
+from ray.rllib.utils import override
+
+torch, nn = try_import_torch()
+
+
+class Connect4MaskModel(TorchModelV2, nn.Module):
+    """PyTorch action-masking model for Connect Four."""
+
+    def __init__(
+        self,
+        obs_space,
+        action_space,
+        num_outputs,
+        model_config,
+        name,
+        **kwargs,
+    ):
+        orig_space = getattr(obs_space, "original_space", obs_space)
+
+        assert isinstance(orig_space, Dict)
+        assert "action_mask" in orig_space.spaces
+        assert "observation" in orig_space.spaces
+
+        TorchModelV2.__init__(
+            self, obs_space, action_space, num_outputs, model_config, name, **kwargs
+        )
+        nn.Module.__init__(self)
+
+        self.internal_model = TorchFC(
+            orig_space["observation"],
+            action_space,
+            num_outputs,
+            model_config,
+            name + "_internal",
+        )
+
+    def forward(self, input_dict, state, seq_lens):
+        # Extract the available actions tensor from the observation.
+        action_mask = input_dict["obs"]["action_mask"]
+
+        # Compute the unmasked logits.
+        logits, _ = self.internal_model({"obs": input_dict["obs"]["observation"]})
+
+        # Convert action_mask into a [0.0 || -inf]-type mask.
+        inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
+        masked_logits = logits + inf_mask
+
+        # Return masked logits.
+        return masked_logits, state
+
+    def value_function(self):
+        return self.internal_model.value_function()
+
+
+class SacConnect4MaskModel(SACTorchModel):
+    def __init__(
+        self,
+        obs_space,
+        action_space,
+        num_outputs,
+        model_config,
+        name: str,
+        policy_model_config=None,
+        q_model_config=None,
+        twin_q=False,
+        initial_alpha=1.0,
+        target_entropy=None,
+        **kwargs,
+    ):
+        orig_space = getattr(obs_space, "original_space", obs_space)
+
+        assert isinstance(orig_space, Dict)
+        assert "action_mask" in orig_space.spaces
+        assert "observation" in orig_space.spaces
+
+        super().__init__(
+            obs_space,
+            action_space,
+            num_outputs,
+            model_config,
+            name,
+            policy_model_config,
+            q_model_config,
+            twin_q,
+            initial_alpha,
+            target_entropy,
+            **kwargs,
+        )
+
+        self.internal_model = TorchFC(
+            orig_space["observation"],
+            action_space,
+            num_outputs,
+            model_config,
+            name + "_internal",
+        )
+
+    @override(SACTorchModel)
+    def forward(self, input_dict, state, seq_lens):
+        # Extract the available actions tensor from the observation.
+        action_mask = input_dict["obs"]["action_mask"]
+
+        # Compute the unmasked logits.
+        logits, _ = self.internal_model({"obs": input_dict["obs"]["observation"]})
+
+        # Convert action_mask into a [0.0 || -inf]-type mask.
+        inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
+        masked_logits = logits + inf_mask
+
+        # Return masked logits.
+        return masked_logits, state
+
+    def value_function(self):
+        return self.internal_model.value_function()
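
The masking line is the heart of both models: log(1) = 0 leaves legal logits unchanged, while log(0) = -inf is clamped to FLOAT_MIN, a value so low that softmax assigns the action effectively zero probability. A small self-contained check (the logit values are arbitrary):

import torch
from ray.rllib.utils.torch_utils import FLOAT_MIN

action_mask = torch.tensor([1.0, 0.0, 1.0])  # middle action is illegal
logits = torch.tensor([0.2, 3.0, -0.5])

inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
probs = torch.softmax(logits + inf_mask, dim=-1)
print(probs)  # the illegal action's probability is numerically zero
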
connectfour/training/train.py ADDED
@@ -0,0 +1,140 @@
+import argparse
+import random
+
+import ray
+import ray.rllib.algorithms.ppo as ppo
+from pettingzoo.classic import connect_four_v3
+from ray import air, tune
+from ray.rllib.policy.policy import PolicySpec
+from ray.rllib.utils.framework import try_import_torch
+from ray.tune import CLIReporter, register_env
+
+from connectfour.training.callbacks import create_self_play_callback
+from connectfour.training.dummy_policies import (
+    AlwaysSameHeuristic,
+    BeatLastHeuristic,
+    LinearHeuristic,
+    RandomHeuristic,
+)
+from connectfour.training.models import Connect4MaskModel
+from connectfour.training.wrappers import Connect4Env
+
+torch, nn = try_import_torch()
+
+
+def get_cli_args():
+    """
+    Create the CLI parser and return the parsed arguments.
+
+    python connectfour/training/train.py --num-cpus 4 --num-gpus 1 --stop-iters 10 --win-rate-threshold 0.50
+    python connectfour/training/train.py --num-gpus 1 --stop-iters 10 --win-rate-threshold 0.50
+    python connectfour/training/train.py --num-cpus 5 --num-gpus 1 --stop-iters 200
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--num-cpus", type=int, default=0)
+    parser.add_argument("--num-gpus", type=int, default=0)
+    parser.add_argument("--num-workers", type=int, default=2)
+
+    parser.add_argument(
+        "--stop-iters", type=int, default=200, help="Number of iterations to train."
+    )
+    parser.add_argument(
+        "--stop-timesteps",
+        type=int,
+        default=10000000,
+        help="Number of timesteps to train.",
+    )
+    parser.add_argument(
+        "--win-rate-threshold",
+        type=float,
+        default=0.95,
+        help="Win-rate at which we set up another opponent by freezing the "
+        "current main policy and playing against a uniform distribution "
+        "of previously frozen 'main's from here on.",
+    )
+    args = parser.parse_args()
+    print(f"Running with following CLI args: {args}")
+    return args
+
+
+def select_policy(agent_id, episode, **kwargs):
+    if episode.episode_id % 2 == int(agent_id[-1:]):
+        return "learned"
+    else:
+        return random.choice(["always_same", "beat_last", "random", "linear"])
+
+
+if __name__ == "__main__":
+    args = get_cli_args()
+
+    ray.init(
+        num_cpus=args.num_cpus or None, num_gpus=args.num_gpus, include_dashboard=False
+    )
+
+    # define how to make the environment
+    env_creator = lambda config: connect_four_v3.env(render_mode="rgb_array")
+
+    # register the environment under an rllib name
+    register_env("connect4", lambda config: Connect4Env(env_creator(config)))
+
+    config = (
+        ppo.PPOConfig()
+        .environment("connect4")
+        .framework("torch")
+        .training(model={"custom_model": Connect4MaskModel})
+        .callbacks(
+            create_self_play_callback(
+                win_rate_thr=args.win_rate_threshold,
+                opponent_policies=["always_same", "beat_last", "random", "linear"],
+            )
+        )
+        .rollouts(
+            num_rollout_workers=args.num_workers,
+            num_envs_per_worker=5,
+        )
+        .multi_agent(
+            policies={
+                "learned": PolicySpec(),
+                "always_same": PolicySpec(policy_class=AlwaysSameHeuristic),
+                "linear": PolicySpec(policy_class=LinearHeuristic),
+                "beat_last": PolicySpec(policy_class=BeatLastHeuristic),
+                "random": PolicySpec(policy_class=RandomHeuristic),
+            },
+            policy_mapping_fn=select_policy,
+            policies_to_train=["learned"],
+        )
+    )
+
+    stop = {
+        "timesteps_total": args.stop_timesteps,
+        "training_iteration": args.stop_iters,
+    }
+
+    results = tune.Tuner(
+        "PPO",
+        param_space=config.to_dict(),
+        run_config=air.RunConfig(
+            stop=stop,
+            verbose=2,
+            progress_reporter=CLIReporter(
+                metric_columns={
+                    "training_iteration": "iter",
+                    "time_total_s": "time_total_s",
+                    "timesteps_total": "ts",
+                    "episodes_this_iter": "train_episodes",
+                    "policy_reward_mean/learned": "reward",
+                    "win_rate": "win_rate",
+                    "league_size": "league_size",
+                },
+                sort_by_metric=True,
+            ),
+            checkpoint_config=air.CheckpointConfig(
+                checkpoint_at_end=True,
+                checkpoint_frequency=10,
+            ),
+        ),
+    ).fit()
+
+    print("Best checkpoint", results.get_best_result().checkpoint)
+
+    ray.shutdown()
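
The seat rotation in select_policy (and in the callback's policy_mapping_fn) hinges on the agent ids ending in 0 or 1: on even episode ids the learned policy sits in seat 0 (first mover), on odd ids in seat 1, so it trains from both sides of the board. The rule in isolation:

# agent ids are "player_0" / "player_1"; the last character picks the seat
for episode_id in (10, 11):
    for agent_id in ("player_0", "player_1"):
        role = "learned" if episode_id % 2 == int(agent_id[-1:]) else "opponent"
        print(episode_id, agent_id, role)
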
connectfour/training/wrappers.py ADDED
@@ -0,0 +1,112 @@
+from typing import Optional
+
+from ray.rllib.env.multi_agent_env import MultiAgentEnv
+from ray.rllib.utils.annotations import PublicAPI
+from ray.rllib.utils.gym import convert_old_gym_space_to_gymnasium_space
+
+
+@PublicAPI
+class Connect4Env(MultiAgentEnv):
+    """An interface to the PettingZoo MARL environment library.
+
+    See: https://github.com/Farama-Foundation/PettingZoo
+
+    Inherits from MultiAgentEnv and exposes a given AEC
+    (actor-environment-cycle) game from the PettingZoo project via the
+    MultiAgentEnv public API.
+
+    Note that the wrapper has some important limitations:
+
+    1. All agents have the same action_spaces and observation_spaces.
+       Note: If, within your AEC game, agents do not have homogeneous action /
+       observation spaces, apply SuperSuit wrappers
+       to apply padding functionality: https://github.com/Farama-Foundation/
+       SuperSuit#built-in-multi-agent-only-functions
+    2. Environments are positive-sum games (-> agents are expected to cooperate
+       to maximize reward). This isn't a hard restriction; it's just that
+       standard algorithms aren't expected to work well in highly competitive
+       games."""
+
+    def __init__(self, env):
+        super().__init__()
+        self.env = env
+        env.reset()
+
+        # Since all agents have the same spaces, do not provide full observation-
+        # and action-spaces as Dicts, mapping agent IDs to the individual
+        # agents' spaces. Instead, `self.[action|observation]_space` are the single
+        # agent spaces.
+        self._obs_space_in_preferred_format = False
+        self._action_space_in_preferred_format = False
+
+        # Collect the individual agents' spaces (they should all be the same):
+        first_obs_space = self.env.observation_space(self.env.agents[0])
+        first_action_space = self.env.action_space(self.env.agents[0])
+
+        for agent in self.env.agents:
+            if self.env.observation_space(agent) != first_obs_space:
+                raise ValueError(
+                    "Observation spaces for all agents must be identical. Perhaps "
+                    "SuperSuit's pad_observations wrapper can help (usage: "
+                    "`supersuit.aec_wrappers.pad_observations(env)`"
+                )
+            if self.env.action_space(agent) != first_action_space:
+                raise ValueError(
+                    "Action spaces for all agents must be identical. Perhaps "
+                    "SuperSuit's pad_action_space wrapper can help (usage: "
+                    "`supersuit.aec_wrappers.pad_action_space(env)`"
+                )
+
+        # Convert from gym to gymnasium, if necessary.
+        self.observation_space = convert_old_gym_space_to_gymnasium_space(
+            first_obs_space
+        )
+        self.action_space = convert_old_gym_space_to_gymnasium_space(first_action_space)
+
+        self._agent_ids = set(self.env.agents)
+
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
+        info = self.env.reset(seed=seed, options=options)
+        return (
+            {self.env.agent_selection: self.env.observe(self.env.agent_selection)},
+            info or {},
+        )
+
+    def step(self, action):
+        self.env.step(action[self.env.agent_selection])
+        obs_d = {}
+        rew_d = {}
+        terminated_d = {}
+        truncated_d = {}
+        info_d = {}
+        while self.env.agents:
+            obs, rew, terminated, truncated, info = self.env.last()
+            agent_id = self.env.agent_selection
+            obs_d[agent_id] = obs
+            rew_d[agent_id] = rew
+            terminated_d[agent_id] = terminated
+            truncated_d[agent_id] = truncated
+            info_d[agent_id] = info
+            if (
+                self.env.terminations[self.env.agent_selection]
+                or self.env.truncations[self.env.agent_selection]
+            ):
+                self.env.step(None)
+            else:
+                break
+
+        all_gone = not self.env.agents
+        terminated_d["__all__"] = all_gone and all(terminated_d.values())
+        truncated_d["__all__"] = all_gone and all(truncated_d.values())
+
+        return obs_d, rew_d, terminated_d, truncated_d, info_d
+
+    def close(self):
+        self.env.close()
+
+    def render(self):
+        return self.env.render()
+
+    @property
+    def get_sub_environments(self):
+        return self.env.unwrapped
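
The wrapper's contract is RLlib's standard multi-agent one: every return value is a dict keyed by agent id, plus the "__all__" termination flags, and only the agent whose turn it is next appears in the observation dict. A quick interactive check (the column index is arbitrary):

from pettingzoo.classic import connect_four_v3
from connectfour.training.wrappers import Connect4Env

env = Connect4Env(connect_four_v3.env(render_mode="rgb_array"))
obs, info = env.reset()
print(list(obs))  # ['player_0'] -- only the agent that acts next

agent = env.env.agent_selection
obs, rew, terminated, truncated, info = env.step({agent: 3})  # drop in column 3
print(terminated["__all__"], list(obs))  # False ['player_1']
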
poetry.lock ADDED
The diff for this file is too large to render.
pyproject.toml ADDED
@@ -0,0 +1,37 @@
+# https://python-poetry.org/docs/pyproject/
+[tool.poetry]
+name = "connectfour"
+version = "0.1.0"
+description = "Connect Four"
+authors = ["Clément Brutti-Mairesse <clement.brutti.mairesse@gmail.com>"]
+license = "MIT"
+readme = "README.md"
+homepage = "https://huggingface.co/spaces/ClementBM/connectfour"
+repository = "https://huggingface.co/spaces/ClementBM/connectfour"
+keywords = ["connectfour", "connect4", "reinforcement learning"]
+include = [
+    "LICENSE",
+]
+
+[tool.poetry.dependencies]
+python = ">=3.8,<3.11"
+orjson = "3.8.8"
+gradio = "^3.23.0"
+ray = {extras = ["rllib", "serve"], version = "^2.2.0"}
+pettingzoo = "^1.22.4"
+pygame = "^2.3.0"
+torch = "^2.0.0"
+libclang = "15.0.6.1"
+tensorflow-probability = "^0.19.0"
+protobuf = "3.17.0"
+scipy = ">=1.8,<1.9.2"
+
+[tool.poetry.dev-dependencies]
+pylint = "*"
+pytest = "*"
+mypy = "*"
+black = "*"
+
+[build-system]
+requires = ["poetry-core>=1.0.0"]
+build-backend = "poetry.core.masonry.api"
requirements.txt ADDED
@@ -0,0 +1,141 @@
+absl-py==1.4.0 ; python_version >= "3.8" and python_version < "3.11"
+aiofiles==23.1.0 ; python_version >= "3.8" and python_version < "3.11"
+aiohttp-cors==0.7.0 ; python_version >= "3.8" and python_version < "3.11"
+aiohttp==3.8.4 ; python_version >= "3.8" and python_version < "3.11"
+aiorwlock==1.3.0 ; python_version >= "3.8" and python_version < "3.11"
+aiosignal==1.3.1 ; python_version >= "3.8" and python_version < "3.11"
+altair==4.2.2 ; python_version >= "3.8" and python_version < "3.11"
+ansicon==1.89.0 ; python_version >= "3.8" and python_version < "3.11" and platform_system == "Windows"
+anyio==3.6.2 ; python_version >= "3.8" and python_version < "3.11"
+async-timeout==4.0.2 ; python_version >= "3.8" and python_version < "3.11"
+attrs==22.2.0 ; python_version >= "3.8" and python_version < "3.11"
+blessed==1.20.0 ; python_version >= "3.8" and python_version < "3.11"
+cachetools==5.3.0 ; python_version >= "3.8" and python_version < "3.11"
+certifi==2022.12.7 ; python_version >= "3.8" and python_version < "3.11"
+charset-normalizer==3.1.0 ; python_version >= "3.8" and python_version < "3.11"
+click==8.1.3 ; python_version >= "3.8" and python_version < "3.11"
+cloudpickle==2.2.1 ; python_version >= "3.8" and python_version < "3.11"
+cmake==3.26.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+colorama==0.4.6 ; python_version >= "3.8" and python_version < "3.11" and platform_system == "Windows"
+colorful==0.5.5 ; python_version >= "3.8" and python_version < "3.11"
+contourpy==1.0.7 ; python_version >= "3.8" and python_version < "3.11"
+cycler==0.11.0 ; python_version >= "3.8" and python_version < "3.11"
+decorator==5.1.1 ; python_version >= "3.8" and python_version < "3.11"
+distlib==0.3.6 ; python_version >= "3.8" and python_version < "3.11"
+dm-tree==0.1.8 ; python_version >= "3.8" and python_version < "3.11"
+entrypoints==0.4 ; python_version >= "3.8" and python_version < "3.11"
+fastapi==0.95.0 ; python_version >= "3.8" and python_version < "3.11"
+ffmpy==0.3.0 ; python_version >= "3.8" and python_version < "3.11"
+filelock==3.10.7 ; python_version >= "3.8" and python_version < "3.11"
+fonttools==4.39.3 ; python_version >= "3.8" and python_version < "3.11"
+frozenlist==1.3.3 ; python_version >= "3.8" and python_version < "3.11"
+fsspec==2023.3.0 ; python_version >= "3.8" and python_version < "3.11"
+gast==0.5.3 ; python_version >= "3.8" and python_version < "3.11"
+google-api-core==2.8.2 ; python_version >= "3.8" and python_version < "3.11"
+google-auth==2.17.0 ; python_version >= "3.8" and python_version < "3.11"
+googleapis-common-protos==1.56.4 ; python_version >= "3.8" and python_version < "3.11"
+gpustat==1.0.0 ; python_version >= "3.8" and python_version < "3.11"
+gradio==3.23.0 ; python_version >= "3.8" and python_version < "3.11"
+grpcio==1.49.1 ; python_version >= "3.8" and python_version < "3.11" and sys_platform == "darwin"
+grpcio==1.53.0 ; python_version >= "3.8" and python_version < "3.11" and sys_platform != "darwin"
+gymnasium-notices==0.0.1 ; python_version >= "3.8" and python_version < "3.11"
+gymnasium==0.26.3 ; python_version >= "3.8" and python_version < "3.11"
+h11==0.14.0 ; python_version >= "3.8" and python_version < "3.11"
+httpcore==0.16.3 ; python_version >= "3.8" and python_version < "3.11"
+httpx==0.23.3 ; python_version >= "3.8" and python_version < "3.11"
+huggingface-hub==0.13.3 ; python_version >= "3.8" and python_version < "3.11"
+idna==3.4 ; python_version >= "3.8" and python_version < "3.11"
+imageio==2.27.0 ; python_version >= "3.8" and python_version < "3.11"
+importlib-metadata==6.1.0 ; python_version >= "3.8" and python_version < "3.10"
+importlib-resources==5.12.0 ; python_version >= "3.8" and python_version < "3.10"
+jinja2==3.1.2 ; python_version >= "3.8" and python_version < "3.11"
+jinxed==1.2.0 ; python_version >= "3.8" and python_version < "3.11" and platform_system == "Windows"
+jsonschema==4.17.3 ; python_version >= "3.8" and python_version < "3.11"
+kiwisolver==1.4.4 ; python_version >= "3.8" and python_version < "3.11"
+lazy-loader==0.2 ; python_version >= "3.8" and python_version < "3.11"
+libclang==15.0.6.1 ; python_version >= "3.8" and python_version < "3.11"
+linkify-it-py==2.0.0 ; python_version >= "3.8" and python_version < "3.11"
+lit==16.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+lz4==4.3.2 ; python_version >= "3.8" and python_version < "3.11"
+markdown-it-py==2.2.0 ; python_version >= "3.8" and python_version < "3.11"
+markdown-it-py[linkify]==2.2.0 ; python_version >= "3.8" and python_version < "3.11"
+markupsafe==2.1.2 ; python_version >= "3.8" and python_version < "3.11"
+matplotlib==3.7.1 ; python_version >= "3.8" and python_version < "3.11"
+mdit-py-plugins==0.3.3 ; python_version >= "3.8" and python_version < "3.11"
+mdurl==0.1.2 ; python_version >= "3.8" and python_version < "3.11"
+mpmath==1.3.0 ; python_version >= "3.8" and python_version < "3.11"
+msgpack==1.0.5 ; python_version >= "3.8" and python_version < "3.11"
+multidict==6.0.4 ; python_version >= "3.8" and python_version < "3.11"
+networkx==3.0 ; python_version >= "3.8" and python_version < "3.11"
+numpy==1.24.2 ; python_version < "3.11" and python_version >= "3.8"
+nvidia-cublas-cu11==11.10.3.66 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+nvidia-cuda-cupti-cu11==11.7.101 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+nvidia-cuda-nvrtc-cu11==11.7.99 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+nvidia-cuda-runtime-cu11==11.7.99 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+nvidia-cudnn-cu11==8.5.0.96 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+nvidia-cufft-cu11==10.9.0.58 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+nvidia-curand-cu11==10.2.10.91 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+nvidia-cusolver-cu11==11.4.0.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+nvidia-cusparse-cu11==11.7.4.91 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+nvidia-ml-py==11.495.46 ; python_version >= "3.8" and python_version < "3.11"
+nvidia-nccl-cu11==2.14.3 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+nvidia-nvtx-cu11==11.7.91 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+opencensus-context==0.1.3 ; python_version >= "3.8" and python_version < "3.11"
+opencensus==0.11.2 ; python_version >= "3.8" and python_version < "3.11"
+orjson==3.8.8 ; python_version >= "3.8" and python_version < "3.11"
+packaging==23.0 ; python_version < "3.11" and python_version >= "3.8"
+pandas==1.5.3 ; python_version >= "3.8" and python_version < "3.11"
+pettingzoo==1.22.4 ; python_version >= "3.8" and python_version < "3.11"
+pillow==9.4.0 ; python_version >= "3.8" and python_version < "3.11"
+pkgutil-resolve-name==1.3.10 ; python_version >= "3.8" and python_version < "3.9"
+platformdirs==3.2.0 ; python_version >= "3.8" and python_version < "3.11"
+prometheus-client==0.16.0 ; python_version >= "3.8" and python_version < "3.11"
+protobuf==3.17.0 ; python_version >= "3.8" and python_version < "3.11"
+psutil==5.9.4 ; python_version >= "3.8" and python_version < "3.11"
+py-spy==0.3.14 ; python_version >= "3.8" and python_version < "3.11"
+pyasn1-modules==0.2.8 ; python_version >= "3.8" and python_version < "3.11"
+pyasn1==0.4.8 ; python_version >= "3.8" and python_version < "3.11"
+pydantic==1.10.7 ; python_version >= "3.8" and python_version < "3.11"
+pydub==0.25.1 ; python_version >= "3.8" and python_version < "3.11"
+pygame==2.3.0 ; python_version >= "3.8" and python_version < "3.11"
+pygments==2.14.0 ; python_version >= "3.8" and python_version < "3.11"
+pyparsing==3.0.9 ; python_version >= "3.8" and python_version < "3.11"
+pyrsistent==0.19.3 ; python_version >= "3.8" and python_version < "3.11"
+python-dateutil==2.8.2 ; python_version >= "3.8" and python_version < "3.11"
+python-multipart==0.0.6 ; python_version >= "3.8" and python_version < "3.11"
+pytz==2023.3 ; python_version >= "3.8" and python_version < "3.11"
+pywavelets==1.4.1 ; python_version >= "3.8" and python_version < "3.11"
+pyyaml==6.0 ; python_version >= "3.8" and python_version < "3.11"
+ray[rllib,serve]==2.3.1 ; python_version >= "3.8" and python_version < "3.11"
+requests==2.28.2 ; python_version >= "3.8" and python_version < "3.11"
+rfc3986[idna2008]==1.5.0 ; python_version >= "3.8" and python_version < "3.11"
+rich==13.3.3 ; python_version >= "3.8" and python_version < "3.11"
+rsa==4.9 ; python_version >= "3.8" and python_version < "3.11"
+scikit-image==0.20.0 ; python_version >= "3.8" and python_version < "3.11"
+scipy==1.9.1 ; python_version < "3.11" and python_version >= "3.8"
+semantic-version==2.10.0 ; python_version >= "3.8" and python_version < "3.11"
+setuptools==67.6.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+six==1.16.0 ; python_version < "3.11" and python_version >= "3.8"
+smart-open==6.3.0 ; python_version >= "3.8" and python_version < "3.11"
+sniffio==1.3.0 ; python_version >= "3.8" and python_version < "3.11"
+starlette==0.26.1 ; python_version >= "3.8" and python_version < "3.11"
+sympy==1.11.1 ; python_version >= "3.8" and python_version < "3.11"
+tabulate==0.9.0 ; python_version >= "3.8" and python_version < "3.11"
+tensorboardx==2.6 ; python_version >= "3.8" and python_version < "3.11"
+tensorflow-probability==0.19.0 ; python_version >= "3.8" and python_version < "3.11"
+tifffile==2023.3.21 ; python_version >= "3.8" and python_version < "3.11"
+toolz==0.12.0 ; python_version >= "3.8" and python_version < "3.11"
+torch==2.0.0 ; python_version >= "3.8" and python_version < "3.11"
+tqdm==4.65.0 ; python_version >= "3.8" and python_version < "3.11"
+triton==2.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+typer==0.7.0 ; python_version >= "3.8" and python_version < "3.11"
+typing-extensions==4.5.0 ; python_version >= "3.8" and python_version < "3.11"
+uc-micro-py==1.0.1 ; python_version >= "3.8" and python_version < "3.11"
+urllib3==1.26.15 ; python_version >= "3.8" and python_version < "3.11"
+uvicorn==0.21.1 ; python_version >= "3.8" and python_version < "3.11"
+virtualenv==20.21.0 ; python_version >= "3.8" and python_version < "3.11"
+wcwidth==0.2.6 ; python_version >= "3.8" and python_version < "3.11"
+websockets==10.4 ; python_version >= "3.8" and python_version < "3.11"
+wheel==0.40.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+yarl==1.8.2 ; python_version >= "3.8" and python_version < "3.11"
+zipp==3.15.0 ; python_version >= "3.8" and python_version < "3.10"