first commit

- .vscode/extensions.json +7 -0
- .vscode/settings.json +28 -0
- README.md +1 -1
- connectfour/__init__.py +0 -0
- connectfour/__pycache__/__init__.cpython-38.pyc +0 -0
- connectfour/__pycache__/app.cpython-38.pyc +0 -0
- connectfour/app.py +250 -0
- connectfour/checkpoint/.Rhistory +0 -0
- connectfour/checkpoint/.is_checkpoint +0 -0
- connectfour/checkpoint/.tune_metadata +0 -0
- connectfour/checkpoint/__init__.py +3 -0
- connectfour/checkpoint/algorithm_state.pkl +3 -0
- connectfour/checkpoint/policies/always_same/policy_state.pkl +3 -0
- connectfour/checkpoint/policies/always_same/rllib_checkpoint.json +1 -0
- connectfour/checkpoint/policies/beat_last/policy_state.pkl +3 -0
- connectfour/checkpoint/policies/beat_last/rllib_checkpoint.json +1 -0
- connectfour/checkpoint/policies/learned/policy_state.pkl +3 -0
- connectfour/checkpoint/policies/learned/rllib_checkpoint.json +1 -0
- connectfour/checkpoint/policies/learned_v1/policy_state.pkl +3 -0
- connectfour/checkpoint/policies/learned_v1/rllib_checkpoint.json +1 -0
- connectfour/checkpoint/policies/learned_v2/policy_state.pkl +3 -0
- connectfour/checkpoint/policies/learned_v2/rllib_checkpoint.json +1 -0
- connectfour/checkpoint/policies/learned_v3/policy_state.pkl +3 -0
- connectfour/checkpoint/policies/learned_v3/rllib_checkpoint.json +1 -0
- connectfour/checkpoint/policies/learned_v4/policy_state.pkl +3 -0
- connectfour/checkpoint/policies/learned_v4/rllib_checkpoint.json +1 -0
- connectfour/checkpoint/policies/learned_v5/policy_state.pkl +3 -0
- connectfour/checkpoint/policies/learned_v5/rllib_checkpoint.json +1 -0
- connectfour/checkpoint/policies/linear/policy_state.pkl +3 -0
- connectfour/checkpoint/policies/linear/rllib_checkpoint.json +1 -0
- connectfour/checkpoint/policies/random/policy_state.pkl +3 -0
- connectfour/checkpoint/policies/random/rllib_checkpoint.json +1 -0
- connectfour/checkpoint/rllib_checkpoint.json +1 -0
- connectfour/training/__init__.py +0 -0
- connectfour/training/__pycache__/__init__.cpython-38.pyc +0 -0
- connectfour/training/__pycache__/callbacks.cpython-38.pyc +0 -0
- connectfour/training/__pycache__/dummy_policies.cpython-38.pyc +0 -0
- connectfour/training/__pycache__/models.cpython-38.pyc +0 -0
- connectfour/training/__pycache__/wrappers.cpython-38.pyc +0 -0
- connectfour/training/callbacks.py +88 -0
- connectfour/training/dummy_policies.py +130 -0
- connectfour/training/models.py +119 -0
- connectfour/training/train.py +140 -0
- connectfour/training/wrappers.py +112 -0
- poetry.lock +0 -0
- pyproject.toml +37 -0
- requirements.txt +141 -0
.vscode/extensions.json
ADDED
@@ -0,0 +1,7 @@
+{
+    // See https://go.microsoft.com/fwlink/?LinkId=827846
+    // for the documentation about the extensions.json format
+    "recommendations": [
+        "ms-python.python"
+    ]
+}
.vscode/settings.json
ADDED
@@ -0,0 +1,28 @@
+{
+    "python.linting.enabled": true,
+    "python.linting.mypyEnabled": true,
+    "python.pythonPath": "${env:PYTHON_VENV_LOC}",
+    "python.testing.unittestEnabled": false,
+    "python.testing.nosetestsEnabled": false,
+    "python.testing.pytestEnabled": true,
+    "python.testing.pytestArgs": [
+        "tests"
+    ],
+    "editor.tabSize": 4,
+    "[python]": {
+        "editor.formatOnSave": true
+    },
+    "python.formatting.provider": "black",
+    "files.exclude": {
+        ".mypy_cache": true,
+        ".pytest_cache": true,
+        ".venv": true,
+        "**/__pycache__": true
+    },
+    "files.watcherExclude": {
+        ".venv/**": true,
+        "**/__pycache__/**": true,
+        ".mypy_cache/**": true,
+        ".pytest_cache/**": true
+    }
+}
README.md
CHANGED
@@ -5,7 +5,7 @@ colorFrom: pink
 colorTo: blue
 sdk: gradio
 sdk_version: 3.23.0
-app_file: app.py
+app_file: connectfour/app.py
 pinned: false
 ---
 
connectfour/__init__.py
ADDED
File without changes

connectfour/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (173 Bytes).

connectfour/__pycache__/app.cpython-38.pyc
ADDED
Binary file (6.51 kB).
connectfour/app.py
ADDED
@@ -0,0 +1,250 @@
+import gradio as gr
+from ray.serve.gradio_integrations import GradioServer, GradioIngress
+
+import gradio as gr
+from pettingzoo.classic import connect_four_v3
+import ray.rllib.algorithms.ppo as ppo
+import numpy as np
+import time
+from ray.tune import register_env
+from connectfour.training.models import Connect4MaskModel
+from connectfour.checkpoint import CHECKPOINT
+
+from connectfour.training.wrappers import Connect4Env
+
+demo = gr.Blocks()
+
+POLICY_ID = "learned_v5"
+
+
+class Connect4:
+    def __init__(self, who_plays_first) -> None:
+        self.init_env(who_plays_first)
+
+    def init_env(self, who_plays_first):
+        # Define how to make the environment.
+        env_creator = lambda config: connect_four_v3.env(render_mode="rgb_array")
+
+        # Register the environment under an RLlib name.
+        register_env("connect4", lambda config: Connect4Env(env_creator(config)))
+
+        orig_env = connect_four_v3.env(render_mode="rgb_array")
+        self.env = Connect4Env(orig_env)
+
+        self.done = False
+        self.obs, info = self.env.reset()
+
+        if who_plays_first == "You":
+            self.human = self.player_id
+        else:
+            self.play()
+            self.human = self.player_id
+
+        return self.render_and_state
+
+    def get_algo(self, checkpoint):
+        config = (
+            ppo.PPOConfig()
+            .environment("connect4")
+            .framework("torch")
+            .training(model={"custom_model": Connect4MaskModel})
+        )
+        config.explore = False
+        self.algo = config.build()
+        self.algo.restore(checkpoint)
+
+    def play(self, action=None):
+        if self.human != self.player_id:
+            action = self.algo.compute_single_action(
+                self.obs[self.player_id], policy_id=POLICY_ID
+            )
+
+        if action not in self.legal_moves:
+            action = np.random.choice(self.legal_moves)
+
+        player_actions = {self.player_id: action}
+
+        self.obs, self.reward, terminated, truncated, info = self.env.step(
+            player_actions
+        )
+        self.done = terminated["__all__"] or truncated["__all__"]
+        return self.render_and_state
+
+    @property
+    def render_and_state(self):
+        end_message = "End of the game"
+        if self.done:
+            if self.reward[self.human] > 0:
+                end_message += ": You WIN !!"
+            elif self.reward[self.human] < 0:
+                end_message += ": You LOSE !!"
+            return self.env.render(), end_message
+
+        return self.env.render(), "Game On"
+
+    @property
+    def player_id(self):
+        return list(self.obs.keys())[0]
+
+    @property
+    def legal_moves(self):
+        return np.arange(7)[self.obs[self.player_id]["action_mask"] == 1]
+
+
+with demo:
+    connect4 = Connect4("You")
+    connect4.get_algo(str(CHECKPOINT))
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            gr.Markdown("# Let's Play Connect Four !")
+
+            who_plays_first = gr.Radio(
+                label="Who plays first", choices=["You", "Bot"], value="You"
+            )
+            reinitialize = gr.Button("New Game")
+
+            game_state = gr.Text(value="Game On", interactive=False, label="Status")
+
+        with gr.Column(scale=1):
+            output = gr.Image(
+                label="Connect Four Grid",
+                type="numpy",
+                show_label=False,
+                value=connect4.env.render(),
+            )
+
+    with gr.Row():
+        with gr.Column(scale=1, min_width=20):
+            drop_token0_btn = gr.Button("X")
+        with gr.Column(scale=1, min_width=20):
+            drop_token1_btn = gr.Button("X")
+        with gr.Column(scale=1, min_width=20):
+            drop_token2_btn = gr.Button("X")
+        with gr.Column(scale=1, min_width=20):
+            drop_token3_btn = gr.Button("X")
+        with gr.Column(scale=1, min_width=20):
+            drop_token4_btn = gr.Button("X")
+        with gr.Column(scale=1, min_width=20):
+            drop_token5_btn = gr.Button("X")
+        with gr.Column(scale=1, min_width=20):
+            drop_token6_btn = gr.Button("X")
+
+    who_plays_first.change(
+        connect4.init_env, who_plays_first, outputs=[output, game_state]
+    )
+
+    def reinit_game(who_plays_first):
+        output, game_state = connect4.init_env(who_plays_first)
+        return output, game_state, gr.Checkbox.update(interactive=True)
+
+    reinitialize.click(
+        reinit_game, who_plays_first, outputs=[output, game_state, who_plays_first]
+    )
+
+    def wait(game_state_value):
+        if game_state_value == "Game On":
+            time.sleep(1)
+            return gr.Checkbox.update(interactive=False)
+        else:
+            return gr.Checkbox.update(interactive=True)
+
+    def bot(game_state_value):
+        if game_state_value == "Game On":
+            rendered_env = connect4.play()
+            return *rendered_env, gr.Checkbox.update(interactive=False) if rendered_env[
+                1
+            ] == "Game On" else gr.Checkbox.update(interactive=True)
+        return (
+            gr.Image.update(),
+            game_state_value,
+            gr.Checkbox.update(interactive=True),
+        )
+
+    drop_token0_btn.click(
+        lambda: connect4.play(0),
+        outputs=[output, game_state],
+    ).then(
+        wait, inputs=[game_state], outputs=who_plays_first
+    ).then(bot, inputs=[game_state], outputs=[output, game_state, who_plays_first])
+
+    drop_token1_btn.click(
+        lambda: connect4.play(1),
+        outputs=[output, game_state],
+    ).then(
+        wait, inputs=[game_state], outputs=who_plays_first
+    ).then(bot, inputs=[game_state], outputs=[output, game_state, who_plays_first])
+
+    drop_token2_btn.click(
+        lambda: connect4.play(2),
+        outputs=[output, game_state],
+    ).then(
+        wait, inputs=[game_state], outputs=who_plays_first
+    ).then(bot, inputs=[game_state], outputs=[output, game_state, who_plays_first])
+
+    drop_token3_btn.click(
+        lambda: connect4.play(3),
+        outputs=[output, game_state],
+    ).then(
+        wait, inputs=[game_state], outputs=who_plays_first
+    ).then(bot, inputs=[game_state], outputs=[output, game_state, who_plays_first])
+
+    drop_token4_btn.click(
+        lambda: connect4.play(4),
+        outputs=[output, game_state],
+    ).then(
+        wait, inputs=[game_state], outputs=who_plays_first
+    ).then(bot, inputs=[game_state], outputs=[output, game_state, who_plays_first])
+
+    drop_token5_btn.click(
+        lambda: connect4.play(5),
+        outputs=[output, game_state],
+    ).then(
+        wait, inputs=[game_state], outputs=who_plays_first
+    ).then(bot, inputs=[game_state], outputs=[output, game_state, who_plays_first])
+
+    drop_token6_btn.click(
+        lambda: connect4.play(6),
+        outputs=[output, game_state],
+    ).then(
+        wait, inputs=[game_state], outputs=who_plays_first
+    ).then(bot, inputs=[game_state], outputs=[output, game_state, who_plays_first])
+
+    def game_state_change(value):
+        if value == "Game On":
+            return [
+                gr.Button.update(interactive=True),
+                gr.Button.update(interactive=True),
+                gr.Button.update(interactive=True),
+                gr.Button.update(interactive=True),
+                gr.Button.update(interactive=True),
+                gr.Button.update(interactive=True),
+                gr.Button.update(interactive=True),
+            ]
+        else:
+            return [
+                gr.Button.update(interactive=False),
+                gr.Button.update(interactive=False),
+                gr.Button.update(interactive=False),
+                gr.Button.update(interactive=False),
+                gr.Button.update(interactive=False),
+                gr.Button.update(interactive=False),
+                gr.Button.update(interactive=False),
+            ]
+
+    game_state.change(
+        game_state_change,
+        game_state,
+        outputs=[
+            drop_token0_btn,
+            drop_token1_btn,
+            drop_token2_btn,
+            drop_token3_btn,
+            drop_token4_btn,
+            drop_token5_btn,
+            drop_token6_btn,
+        ],
+    )
+
+
+demo.launch()
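The Blocks wiring above drives everything through UI callbacks, but the same pieces can be exercised head-less. A minimal sketch (hypothetical, not part of the commit; it assumes the pinned stack from requirements.txt, i.e. ray 2.3.1 and pettingzoo 1.22.4) of restoring the bundled checkpoint and querying the frozen "learned_v5" policy for one opening move, mirroring app.py's own setup:

# Hypothetical smoke test mirroring app.py's setup, not part of the commit.
import ray.rllib.algorithms.ppo as ppo
from pettingzoo.classic import connect_four_v3
from ray.tune import register_env

from connectfour.checkpoint import CHECKPOINT
from connectfour.training.models import Connect4MaskModel
from connectfour.training.wrappers import Connect4Env

# Register the env under the name the checkpoint was trained with.
register_env("connect4", lambda cfg: Connect4Env(connect_four_v3.env(render_mode="rgb_array")))

env = Connect4Env(connect_four_v3.env(render_mode="rgb_array"))
obs, _ = env.reset()

config = (
    ppo.PPOConfig()
    .environment("connect4")
    .framework("torch")
    .training(model={"custom_model": Connect4MaskModel})
)
algo = config.build()
algo.restore(str(CHECKPOINT))

agent = next(iter(obs))  # the agent whose turn it is, e.g. "player_0"
action = algo.compute_single_action(obs[agent], policy_id="learned_v5")
print(f"{agent} opens in column {action}")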
connectfour/checkpoint/.Rhistory
ADDED
File without changes

connectfour/checkpoint/.is_checkpoint
ADDED
File without changes

connectfour/checkpoint/.tune_metadata
ADDED
Binary file (15.6 kB).
connectfour/checkpoint/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from pathlib import Path
+
+CHECKPOINT = Path(__file__).parent.absolute()
connectfour/checkpoint/algorithm_state.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cbbc198c3406897931f5f18046a88181b8abff1aedbea1d869329731c9a50853
+size 66321

connectfour/checkpoint/policies/always_same/policy_state.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d278413093ad1bc4f227279e3dab7be04ebd70ca1ed156a1363515c69d0a858e
+size 10992

connectfour/checkpoint/policies/always_same/rllib_checkpoint.json
ADDED
@@ -0,0 +1 @@
+{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}

connectfour/checkpoint/policies/beat_last/policy_state.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cd422258c16de0866599730a5a5b2b48e2ee81cbae69f9d5471deeae76c42b47
+size 10992

connectfour/checkpoint/policies/beat_last/rllib_checkpoint.json
ADDED
@@ -0,0 +1 @@
+{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}

connectfour/checkpoint/policies/learned/policy_state.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2a517583e5fcad7e483bca619723583cc6928499390c1fcfc25d907e109cd4b4
+size 2139442

connectfour/checkpoint/policies/learned/rllib_checkpoint.json
ADDED
@@ -0,0 +1 @@
+{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}

connectfour/checkpoint/policies/learned_v1/policy_state.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:276c26007c2419a688c27f9dfa70c20fecb468a0aa07d28d6a9e8099bbc849be
+size 2139439

connectfour/checkpoint/policies/learned_v1/rllib_checkpoint.json
ADDED
@@ -0,0 +1 @@
+{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}

connectfour/checkpoint/policies/learned_v2/policy_state.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e37a485d3a54f7a8b194693e7a61f790e67071358130178fa01cdbd840c4a4da
+size 2139439

connectfour/checkpoint/policies/learned_v2/rllib_checkpoint.json
ADDED
@@ -0,0 +1 @@
+{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}

connectfour/checkpoint/policies/learned_v3/policy_state.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9f90899ae98a387e312333b234041c68b9c50da4af92ee5250686087a39eebb3
+size 2139439

connectfour/checkpoint/policies/learned_v3/rllib_checkpoint.json
ADDED
@@ -0,0 +1 @@
+{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}

connectfour/checkpoint/policies/learned_v4/policy_state.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3af3b3fe41bac489cb693af387b1ccc4437a532a78d539b3abb4cc5f77929592
+size 2139439

connectfour/checkpoint/policies/learned_v4/rllib_checkpoint.json
ADDED
@@ -0,0 +1 @@
+{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}

connectfour/checkpoint/policies/learned_v5/policy_state.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f2b28b979e2f4411d196e03ca75ea7f25f7601bb997aa8bcdcf1d49c9ea30754
+size 2139439

connectfour/checkpoint/policies/learned_v5/rllib_checkpoint.json
ADDED
@@ -0,0 +1 @@
+{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}

connectfour/checkpoint/policies/linear/policy_state.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4f70d44ac661632dc0557204abe34308dfb25b800a668b49c2efd9a2a73a7bc0
+size 10992

connectfour/checkpoint/policies/linear/rllib_checkpoint.json
ADDED
@@ -0,0 +1 @@
+{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}

connectfour/checkpoint/policies/random/policy_state.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f3b1ab86bada035779feedb2b92ae0a64f6d9474bb4f0ae44324e17d65659764
+size 10992

connectfour/checkpoint/policies/random/rllib_checkpoint.json
ADDED
@@ -0,0 +1 @@
+{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}

connectfour/checkpoint/rllib_checkpoint.json
ADDED
@@ -0,0 +1 @@
+{"type": "Algorithm", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
connectfour/training/__init__.py
ADDED
File without changes

connectfour/training/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (182 Bytes).

connectfour/training/__pycache__/callbacks.cpython-38.pyc
ADDED
Binary file (2.67 kB).

connectfour/training/__pycache__/dummy_policies.cpython-38.pyc
ADDED
Binary file (5.36 kB).

connectfour/training/__pycache__/models.cpython-38.pyc
ADDED
Binary file (2.91 kB).

connectfour/training/__pycache__/wrappers.cpython-38.pyc
ADDED
Binary file (3.72 kB).
connectfour/training/callbacks.py
ADDED
@@ -0,0 +1,88 @@
+from ray.rllib.algorithms.callbacks import DefaultCallbacks
+import numpy as np
+
+
+def create_self_play_callback(win_rate_thr, opponent_policies):
+    class SelfPlayCallback(DefaultCallbacks):
+        win_rate_threshold = win_rate_thr
+
+        def __init__(self):
+            super().__init__()
+            self.current_opponent = 0
+
+        def on_train_result(self, *, algorithm, result, **kwargs):
+            # Get the win rate for the train batch.
+            # Note that normally, one should set up a proper evaluation config,
+            # such that evaluation always happens on the already updated policy,
+            # instead of on the already used train batch.
+            main_rew = result["hist_stats"].pop("policy_learned_reward")
+            opponent_rew = result["hist_stats"].pop("episode_reward")
+
+            if len(main_rew) != len(opponent_rew):
+                raise Exception(
+                    "len(main_rew) != len(opponent_rew)",
+                    len(main_rew),
+                    len(opponent_rew),
+                    result["hist_stats"].keys(),
+                    "episode len",
+                    len(opponent_rew),
+                )
+
+            won = 0
+            for r_main, r_opponent in zip(main_rew, opponent_rew):
+                if r_main > r_opponent:
+                    won += 1
+            win_rate = won / len(main_rew)
+
+            result["win_rate"] = win_rate
+            print(f"Iter={algorithm.iteration} win-rate={win_rate} -> ", end="")
+
+            # If the win rate is good enough -> snapshot the current policy and
+            # play against it next, keeping the snapshot fixed and only
+            # improving the "learned" policy.
+            if win_rate > self.win_rate_threshold:
+                self.current_opponent += 1
+                new_pol_id = f"learned_v{self.current_opponent}"
+                print(
+                    f"Iter={algorithm.iteration} ### Adding new opponent to the mix ({new_pol_id})."
+                )
+
+                # Re-define the mapping function such that "learned" is forced
+                # to play against any of the heuristic opponents or the
+                # previously snapshotted "learned_v*" policies.
+                def policy_mapping_fn(agent_id, episode, worker, **kwargs):
+                    # agent_id = [0|1] -> policy depends on the episode ID.
+                    # This way, both policies sometimes play agent0 (start
+                    # player) and sometimes agent1 (player to move 2nd).
+                    return (
+                        "learned"
+                        if episode.episode_id % 2 == int(agent_id[-1:])
+                        else np.random.choice(
+                            opponent_policies
+                            + [
+                                f"learned_v{i}"
+                                for i in range(1, self.current_opponent + 1)
+                            ]
+                        )
+                    )
+
+                new_policy = algorithm.add_policy(
+                    policy_id=new_pol_id,
+                    policy_cls=type(algorithm.get_policy("learned")),
+                    policy_mapping_fn=policy_mapping_fn,
+                )
+
+                # Set the weights of the new policy to the learned policy.
+                # We'll keep training the learned policy, whereas `new_pol_id`
+                # will remain fixed.
+                learned_state = algorithm.get_policy("learned").get_state()
+                new_policy.set_state(learned_state)
+                # We need to sync the just-copied local weights (from the
+                # learned policy) to all the remote workers as well.
+                algorithm.workers.sync_weights()
+            else:
+                print("not good enough; will keep learning ...")
+
+            result["league_size"] = self.current_opponent + len(opponent_policies) + 1
+
+    return SelfPlayCallback
connectfour/training/dummy_policies.py
ADDED
@@ -0,0 +1,130 @@
+import numpy as np
+import random
+from ray.rllib.policy.policy import Policy
+from ray.rllib.utils.annotations import override
+from ray.rllib.models.modelv2 import restore_original_dimensions
+
+
+class HeuristicBase(Policy):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.exploration = self._create_exploration()
+
+    def learn_on_batch(self, samples):
+        pass
+
+    @override(Policy)
+    def get_weights(self):
+        """No weights to save."""
+        return {}
+
+    @override(Policy)
+    def set_weights(self, weights):
+        """No weights to set."""
+        pass
+
+    @override(Policy)
+    def compute_actions(
+        self,
+        obs_batch,
+        state_batches=None,
+        prev_action_batch=None,
+        prev_reward_batch=None,
+        info_batch=None,
+        episodes=None,
+        **kwargs
+    ):
+        obs_batch = restore_original_dimensions(
+            np.array(obs_batch, dtype=np.float32), self.observation_space, tensorlib=np
+        )
+        return self._do_compute_actions(obs_batch)
+
+    def pick_legal_action(self, legal_action):
+        legal_choices = np.arange(len(legal_action))[legal_action == 1]
+        return np.random.choice(legal_choices)
+
+
+class AlwaysSameHeuristic(HeuristicBase):
+    """
+    Pick a random column and stick with it for the entire episode.
+    """
+
+    _rand_choice = random.choice(range(7))
+
+    def _do_compute_actions(self, obs_batch):
+        def select_action(legal_action):
+            legal_choices = np.arange(len(legal_action))[legal_action == 1]
+
+            if self._rand_choice not in legal_choices:
+                self._rand_choice = np.random.choice(legal_choices)
+
+            return self._rand_choice
+
+        return [select_action(x) for x in obs_batch["action_mask"]], [], {}
+
+
+class LinearHeuristic(HeuristicBase):
+    """
+    Pick a random column, then keep shifting the column index in a fixed direction.
+    """
+
+    _rand_choice = random.choice(range(7))
+    _rand_sign = np.random.choice([-1, 1])
+
+    def _do_compute_actions(self, obs_batch):
+        def select_action(legal_action):
+            legal_choices = np.arange(len(legal_action))[legal_action == 1]
+
+            self._rand_choice += 1 * self._rand_sign
+
+            if self._rand_choice not in legal_choices:
+                self._rand_choice = np.random.choice(legal_choices)
+
+            return self._rand_choice
+
+        return [select_action(x) for x in obs_batch["action_mask"]], [], {}
+
+
+class BeatLastHeuristic(HeuristicBase):
+    """
+    Play in a column where the opponent currently has more tokens than we do,
+    i.e. answer the opponent's last move.
+    """
+
+    def _do_compute_actions(self, obs_batch):
+        def select_action(legal_action, observation):
+            legal_choices = np.arange(len(legal_action))[legal_action == 1]
+
+            obs_sums = np.sum(observation, axis=0)
+
+            desired_actions = np.squeeze(np.argwhere(obs_sums[:, 0] < obs_sums[:, 1]))
+            if desired_actions.size == 0:
+                return np.random.choice(legal_choices)
+
+            if desired_actions.size == 1:
+                desired_action = desired_actions[()]
+            else:
+                desired_action = np.random.choice(desired_actions)
+            if desired_action in legal_choices:
+                return desired_action
+
+            return np.random.choice(legal_choices)
+
+        return (
+            [
+                select_action(x, y)
+                for x, y in zip(obs_batch["action_mask"], obs_batch["observation"])
+            ],
+            [],
+            {},
+        )
+
+
+class RandomHeuristic(HeuristicBase):
+    """
+    Just pick a random legal action.
+    The output state of the environment needs to be a dictionary with an
+    'action_mask' key containing the legal actions for the agent.
+    """
+
+    def _do_compute_actions(self, obs_batch):
+        return [self.pick_legal_action(x) for x in obs_batch["action_mask"]], [], {}
connectfour/training/models.py
ADDED
@@ -0,0 +1,119 @@
+from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
+from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
+from gymnasium.spaces import Dict
+from ray.rllib.utils.torch_utils import FLOAT_MIN
+from ray.rllib.utils.framework import try_import_torch
+from ray.rllib.algorithms.sac.sac_torch_model import SACTorchModel
+from ray.rllib.utils import override
+
+torch, nn = try_import_torch()
+
+
+class Connect4MaskModel(TorchModelV2, nn.Module):
+    """PyTorch action-masking model for Connect Four."""
+
+    def __init__(
+        self,
+        obs_space,
+        action_space,
+        num_outputs,
+        model_config,
+        name,
+        **kwargs,
+    ):
+        orig_space = getattr(obs_space, "original_space", obs_space)
+
+        assert isinstance(orig_space, Dict)
+        assert "action_mask" in orig_space.spaces
+        assert "observation" in orig_space.spaces
+
+        TorchModelV2.__init__(
+            self, obs_space, action_space, num_outputs, model_config, name, **kwargs
+        )
+        nn.Module.__init__(self)
+
+        self.internal_model = TorchFC(
+            orig_space["observation"],
+            action_space,
+            num_outputs,
+            model_config,
+            name + "_internal",
+        )
+
+    def forward(self, input_dict, state, seq_lens):
+        # Extract the available actions tensor from the observation.
+        action_mask = input_dict["obs"]["action_mask"]
+
+        # Compute the unmasked logits.
+        logits, _ = self.internal_model({"obs": input_dict["obs"]["observation"]})
+
+        # Convert action_mask into a [0.0 || -inf]-type mask.
+        inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
+        masked_logits = logits + inf_mask
+
+        # Return masked logits.
+        return masked_logits, state
+
+    def value_function(self):
+        return self.internal_model.value_function()
+
+
+class SacConnect4MaskModel(SACTorchModel):
+    def __init__(
+        self,
+        obs_space,
+        action_space,
+        num_outputs,
+        model_config,
+        name: str,
+        policy_model_config=None,
+        q_model_config=None,
+        twin_q=False,
+        initial_alpha=1.0,
+        target_entropy=None,
+        **kwargs,
+    ):
+        orig_space = getattr(obs_space, "original_space", obs_space)
+
+        assert isinstance(orig_space, Dict)
+        assert "action_mask" in orig_space.spaces
+        assert "observation" in orig_space.spaces
+
+        super().__init__(
+            obs_space,
+            action_space,
+            num_outputs,
+            model_config,
+            name,
+            policy_model_config,
+            q_model_config,
+            twin_q,
+            initial_alpha,
+            target_entropy,
+            **kwargs,
+        )
+
+        self.internal_model = TorchFC(
+            orig_space["observation"],
+            action_space,
+            num_outputs,
+            model_config,
+            name + "_internal",
+        )
+
+    @override(SACTorchModel)
+    def forward(self, input_dict, state, seq_lens):
+        # Extract the available actions tensor from the observation.
+        action_mask = input_dict["obs"]["action_mask"]
+
+        # Compute the unmasked logits.
+        logits, _ = self.internal_model({"obs": input_dict["obs"]["observation"]})
+
+        # Convert action_mask into a [0.0 || -inf]-type mask.
+        inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
+        masked_logits = logits + inf_mask
+
+        # Return masked logits.
+        return masked_logits, state
+
+    def value_function(self):
+        return self.internal_model.value_function()
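Both models rely on the same masking trick, which is worth seeing in isolation. A tiny illustration (not part of the commit; assumes torch 2.0.0 and ray 2.3.1 as pinned above): log(1) = 0 leaves legal logits untouched, while log(0) = -inf is clamped to FLOAT_MIN, so illegal actions end up with essentially zero probability after the softmax:

# Standalone illustration of the [0.0 || -inf] action mask used above.
import torch
from ray.rllib.utils.torch_utils import FLOAT_MIN

logits = torch.tensor([1.0, 2.0, 3.0])
mask = torch.tensor([1.0, 0.0, 1.0])  # column 1 is full / illegal
inf_mask = torch.clamp(torch.log(mask), min=FLOAT_MIN)
probs = torch.softmax(logits + inf_mask, dim=-1)
print(probs)  # ~[0.12, 0.00, 0.88]: the illegal column gets no mass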
connectfour/training/train.py
ADDED
@@ -0,0 +1,140 @@
+import argparse
+import random
+
+import ray
+import ray.rllib.algorithms.ppo as ppo
+from pettingzoo.classic import connect_four_v3
+from ray import air, tune
+from ray.rllib.policy.policy import PolicySpec
+from ray.rllib.utils.framework import try_import_torch
+from ray.tune import CLIReporter, register_env
+
+from connectfour.training.callbacks import create_self_play_callback
+from connectfour.training.dummy_policies import (
+    AlwaysSameHeuristic,
+    BeatLastHeuristic,
+    LinearHeuristic,
+    RandomHeuristic,
+)
+from connectfour.training.models import Connect4MaskModel
+from connectfour.training.wrappers import Connect4Env
+
+torch, nn = try_import_torch()
+
+
+def get_cli_args():
+    """
+    Create the CLI parser and return the parsed arguments.
+
+    python connectfour/training/train.py --num-cpus 4 --num-gpus 1 --stop-iters 10 --win-rate-threshold 0.50
+    python connectfour/training/train.py --num-gpus 1 --stop-iters 10 --win-rate-threshold 0.50
+    python connectfour/training/train.py --num-cpus 5 --num-gpus 1 --stop-iters 200
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--num-cpus", type=int, default=0)
+    parser.add_argument("--num-gpus", type=int, default=0)
+    parser.add_argument("--num-workers", type=int, default=2)
+
+    parser.add_argument(
+        "--stop-iters", type=int, default=200, help="Number of iterations to train."
+    )
+    parser.add_argument(
+        "--stop-timesteps",
+        type=int,
+        default=10000000,
+        help="Number of timesteps to train.",
+    )
+    parser.add_argument(
+        "--win-rate-threshold",
+        type=float,
+        default=0.95,
+        help="Win-rate at which we set up another opponent by freezing the "
+        "current main policy and playing against a uniform distribution "
+        "of previously frozen 'main's from here on.",
+    )
+    args = parser.parse_args()
+    print(f"Running with following CLI args: {args}")
+    return args
+
+
+def select_policy(agent_id, episode, **kwargs):
+    if episode.episode_id % 2 == int(agent_id[-1:]):
+        return "learned"
+    else:
+        return random.choice(["always_same", "beat_last", "random", "linear"])
+
+
+if __name__ == "__main__":
+    args = get_cli_args()
+
+    ray.init(
+        num_cpus=args.num_cpus or None, num_gpus=args.num_gpus, include_dashboard=False
+    )
+
+    # Define how to make the environment.
+    env_creator = lambda config: connect_four_v3.env(render_mode="rgb_array")
+
+    # Register the environment under an RLlib name.
+    register_env("connect4", lambda config: Connect4Env(env_creator(config)))
+
+    config = (
+        ppo.PPOConfig()
+        .environment("connect4")
+        .framework("torch")
+        .training(model={"custom_model": Connect4MaskModel})
+        .callbacks(
+            create_self_play_callback(
+                win_rate_thr=args.win_rate_threshold,
+                opponent_policies=["always_same", "beat_last", "random", "linear"],
+            )
+        )
+        .rollouts(
+            num_rollout_workers=args.num_workers,
+            num_envs_per_worker=5,
+        )
+        .multi_agent(
+            policies={
+                "learned": PolicySpec(),
+                "always_same": PolicySpec(policy_class=AlwaysSameHeuristic),
+                "linear": PolicySpec(policy_class=LinearHeuristic),
+                "beat_last": PolicySpec(policy_class=BeatLastHeuristic),
+                "random": PolicySpec(policy_class=RandomHeuristic),
+            },
+            policy_mapping_fn=select_policy,
+            policies_to_train=["learned"],
+        )
+    )
+
+    stop = {
+        "timesteps_total": args.stop_timesteps,
+        "training_iteration": args.stop_iters,
+    }
+
+    results = tune.Tuner(
+        "PPO",
+        param_space=config.to_dict(),
+        run_config=air.RunConfig(
+            stop=stop,
+            verbose=2,
+            progress_reporter=CLIReporter(
+                metric_columns={
+                    "training_iteration": "iter",
+                    "time_total_s": "time_total_s",
+                    "timesteps_total": "ts",
+                    "episodes_this_iter": "train_episodes",
+                    "policy_reward_mean/learned": "reward",
+                    "win_rate": "win_rate",
+                    "league_size": "league_size",
+                },
+                sort_by_metric=True,
+            ),
+            checkpoint_config=air.CheckpointConfig(
+                checkpoint_at_end=True,
+                checkpoint_frequency=10,
+            ),
+        ),
+    ).fit()
+
+    print("Best checkpoint", results.get_best_result().checkpoint)
+
+    ray.shutdown()
connectfour/training/wrappers.py
ADDED
@@ -0,0 +1,112 @@
+from typing import Optional
+
+from ray.rllib.env.multi_agent_env import MultiAgentEnv
+from ray.rllib.utils.annotations import PublicAPI
+from ray.rllib.utils.gym import convert_old_gym_space_to_gymnasium_space
+
+
+@PublicAPI
+class Connect4Env(MultiAgentEnv):
+    """An interface to the PettingZoo MARL environment library.
+
+    See: https://github.com/Farama-Foundation/PettingZoo
+
+    Inherits from MultiAgentEnv and exposes a given AEC
+    (actor-environment-cycle) game from the PettingZoo project via the
+    MultiAgentEnv public API.
+
+    Note that the wrapper has some important limitations:
+
+    1. All agents have the same action_spaces and observation_spaces.
+       Note: If, within your AEC game, agents do not have homogeneous action /
+       observation spaces, apply SuperSuit wrappers
+       to apply padding functionality: https://github.com/Farama-Foundation/
+       SuperSuit#built-in-multi-agent-only-functions
+    2. Environments are positive-sum games (-> agents are expected to cooperate
+       to maximize reward). This isn't a hard restriction; it's just that
+       standard algorithms aren't expected to work well in highly competitive
+       games."""
+
+    def __init__(self, env):
+        super().__init__()
+        self.env = env
+        env.reset()
+
+        # Since all agents have the same spaces, do not provide full observation-
+        # and action-spaces as Dicts, mapping agent IDs to the individual
+        # agents' spaces. Instead, `self.[action|observation]_space` are the
+        # single-agent spaces.
+        self._obs_space_in_preferred_format = False
+        self._action_space_in_preferred_format = False
+
+        # Collect the individual agents' spaces (they should all be the same):
+        first_obs_space = self.env.observation_space(self.env.agents[0])
+        first_action_space = self.env.action_space(self.env.agents[0])
+
+        for agent in self.env.agents:
+            if self.env.observation_space(agent) != first_obs_space:
+                raise ValueError(
+                    "Observation spaces for all agents must be identical. Perhaps "
+                    "SuperSuit's pad_observations wrapper can help (usage: "
+                    "`supersuit.aec_wrappers.pad_observations(env)`"
+                )
+            if self.env.action_space(agent) != first_action_space:
+                raise ValueError(
+                    "Action spaces for all agents must be identical. Perhaps "
+                    "SuperSuit's pad_action_space wrapper can help (usage: "
+                    "`supersuit.aec_wrappers.pad_action_space(env)`"
+                )
+
+        # Convert from gym to gymnasium, if necessary.
+        self.observation_space = convert_old_gym_space_to_gymnasium_space(
+            first_obs_space
+        )
+        self.action_space = convert_old_gym_space_to_gymnasium_space(first_action_space)
+
+        self._agent_ids = set(self.env.agents)
+
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
+        info = self.env.reset(seed=seed, options=options)
+        return (
+            {self.env.agent_selection: self.env.observe(self.env.agent_selection)},
+            info or {},
+        )
+
+    def step(self, action):
+        self.env.step(action[self.env.agent_selection])
+        obs_d = {}
+        rew_d = {}
+        terminated_d = {}
+        truncated_d = {}
+        info_d = {}
+        while self.env.agents:
+            obs, rew, terminated, truncated, info = self.env.last()
+            agent_id = self.env.agent_selection
+            obs_d[agent_id] = obs
+            rew_d[agent_id] = rew
+            terminated_d[agent_id] = terminated
+            truncated_d[agent_id] = truncated
+            info_d[agent_id] = info
+            if (
+                self.env.terminations[self.env.agent_selection]
+                or self.env.truncations[self.env.agent_selection]
+            ):
+                self.env.step(None)
+            else:
+                break
+
+        all_gone = not self.env.agents
+        terminated_d["__all__"] = all_gone and all(terminated_d.values())
+        truncated_d["__all__"] = all_gone and all(truncated_d.values())
+
+        return obs_d, rew_d, terminated_d, truncated_d, info_d
+
+    def close(self):
+        self.env.close()
+
+    def render(self):
+        return self.env.render()
+
+    @property
+    def get_sub_environments(self):
+        return self.env.unwrapped
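A minimal sketch (hypothetical, not part of the commit; assumes pettingzoo 1.22.4 as pinned) of the wrapper's turn-based dict API: reset() and step() only return entries for the agent whose turn it is.

# Hypothetical usage example of the Connect4Env wrapper.
import numpy as np
from pettingzoo.classic import connect_four_v3

from connectfour.training.wrappers import Connect4Env

env = Connect4Env(connect_four_v3.env(render_mode="rgb_array"))
obs, info = env.reset()

agent = next(iter(obs))                # e.g. "player_0"
mask = obs[agent]["action_mask"]       # 1 = column still open
column = int(np.random.choice(np.arange(7)[mask == 1]))
obs, rew, terminated, truncated, info = env.step({agent: column})
print(terminated["__all__"])           # False after a single opening move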
poetry.lock
ADDED
The diff for this file is too large to render.
pyproject.toml
ADDED
@@ -0,0 +1,37 @@
+# https://python-poetry.org/docs/pyproject/
+[tool.poetry]
+name = "connectfour"
+version = "0.1.0"
+description = "Connect Four"
+authors = ["Clément Brutti-Mairesse <clement.brutti.mairesse@gmail.com>"]
+license = "MIT"
+readme = "README.md"
+homepage = "https://huggingface.co/spaces/ClementBM/connectfour"
+repository = "https://huggingface.co/spaces/ClementBM/connectfour"
+keywords = ["connectfour", "connect4", "reinforcement learning"]
+include = [
+    "LICENSE",
+]
+
+[tool.poetry.dependencies]
+python = ">=3.8,<3.11"
+orjson = "3.8.8"
+gradio = "^3.23.0"
+ray = {extras = ["rllib", "serve"], version = "^2.2.0"}
+pettingzoo = "^1.22.4"
+pygame = "^2.3.0"
+torch = "^2.0.0"
+libclang = "15.0.6.1"
+tensorflow-probability = "^0.19.0"
+protobuf = "3.17.0"
+scipy = ">=1.8,<1.9.2"
+
+[tool.poetry.dev-dependencies]
+pylint = "*"
+pytest = "*"
+mypy = "*"
+black = "*"
+
+[build-system]
+requires = ["poetry-core>=1.0.0"]
+build-backend = "poetry.core.masonry.api"
requirements.txt
ADDED
@@ -0,0 +1,141 @@
+absl-py==1.4.0 ; python_version >= "3.8" and python_version < "3.11"
+aiofiles==23.1.0 ; python_version >= "3.8" and python_version < "3.11"
+aiohttp-cors==0.7.0 ; python_version >= "3.8" and python_version < "3.11"
+aiohttp==3.8.4 ; python_version >= "3.8" and python_version < "3.11"
+aiorwlock==1.3.0 ; python_version >= "3.8" and python_version < "3.11"
+aiosignal==1.3.1 ; python_version >= "3.8" and python_version < "3.11"
+altair==4.2.2 ; python_version >= "3.8" and python_version < "3.11"
+ansicon==1.89.0 ; python_version >= "3.8" and python_version < "3.11" and platform_system == "Windows"
+anyio==3.6.2 ; python_version >= "3.8" and python_version < "3.11"
+async-timeout==4.0.2 ; python_version >= "3.8" and python_version < "3.11"
+attrs==22.2.0 ; python_version >= "3.8" and python_version < "3.11"
+blessed==1.20.0 ; python_version >= "3.8" and python_version < "3.11"
+cachetools==5.3.0 ; python_version >= "3.8" and python_version < "3.11"
+certifi==2022.12.7 ; python_version >= "3.8" and python_version < "3.11"
+charset-normalizer==3.1.0 ; python_version >= "3.8" and python_version < "3.11"
+click==8.1.3 ; python_version >= "3.8" and python_version < "3.11"
+cloudpickle==2.2.1 ; python_version >= "3.8" and python_version < "3.11"
+cmake==3.26.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+colorama==0.4.6 ; python_version >= "3.8" and python_version < "3.11" and platform_system == "Windows"
+colorful==0.5.5 ; python_version >= "3.8" and python_version < "3.11"
+contourpy==1.0.7 ; python_version >= "3.8" and python_version < "3.11"
+cycler==0.11.0 ; python_version >= "3.8" and python_version < "3.11"
+decorator==5.1.1 ; python_version >= "3.8" and python_version < "3.11"
+distlib==0.3.6 ; python_version >= "3.8" and python_version < "3.11"
+dm-tree==0.1.8 ; python_version >= "3.8" and python_version < "3.11"
+entrypoints==0.4 ; python_version >= "3.8" and python_version < "3.11"
+fastapi==0.95.0 ; python_version >= "3.8" and python_version < "3.11"
+ffmpy==0.3.0 ; python_version >= "3.8" and python_version < "3.11"
+filelock==3.10.7 ; python_version >= "3.8" and python_version < "3.11"
+fonttools==4.39.3 ; python_version >= "3.8" and python_version < "3.11"
+frozenlist==1.3.3 ; python_version >= "3.8" and python_version < "3.11"
+fsspec==2023.3.0 ; python_version >= "3.8" and python_version < "3.11"
+gast==0.5.3 ; python_version >= "3.8" and python_version < "3.11"
+google-api-core==2.8.2 ; python_version >= "3.8" and python_version < "3.11"
+google-auth==2.17.0 ; python_version >= "3.8" and python_version < "3.11"
+googleapis-common-protos==1.56.4 ; python_version >= "3.8" and python_version < "3.11"
+gpustat==1.0.0 ; python_version >= "3.8" and python_version < "3.11"
+gradio==3.23.0 ; python_version >= "3.8" and python_version < "3.11"
+grpcio==1.49.1 ; python_version >= "3.8" and python_version < "3.11" and sys_platform == "darwin"
+grpcio==1.53.0 ; python_version >= "3.8" and python_version < "3.11" and sys_platform != "darwin"
+gymnasium-notices==0.0.1 ; python_version >= "3.8" and python_version < "3.11"
+gymnasium==0.26.3 ; python_version >= "3.8" and python_version < "3.11"
+h11==0.14.0 ; python_version >= "3.8" and python_version < "3.11"
+httpcore==0.16.3 ; python_version >= "3.8" and python_version < "3.11"
+httpx==0.23.3 ; python_version >= "3.8" and python_version < "3.11"
+huggingface-hub==0.13.3 ; python_version >= "3.8" and python_version < "3.11"
+idna==3.4 ; python_version >= "3.8" and python_version < "3.11"
+imageio==2.27.0 ; python_version >= "3.8" and python_version < "3.11"
+importlib-metadata==6.1.0 ; python_version >= "3.8" and python_version < "3.10"
+importlib-resources==5.12.0 ; python_version >= "3.8" and python_version < "3.10"
+jinja2==3.1.2 ; python_version >= "3.8" and python_version < "3.11"
+jinxed==1.2.0 ; python_version >= "3.8" and python_version < "3.11" and platform_system == "Windows"
+jsonschema==4.17.3 ; python_version >= "3.8" and python_version < "3.11"
+kiwisolver==1.4.4 ; python_version >= "3.8" and python_version < "3.11"
+lazy-loader==0.2 ; python_version >= "3.8" and python_version < "3.11"
+libclang==15.0.6.1 ; python_version >= "3.8" and python_version < "3.11"
+linkify-it-py==2.0.0 ; python_version >= "3.8" and python_version < "3.11"
+lit==16.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+lz4==4.3.2 ; python_version >= "3.8" and python_version < "3.11"
+markdown-it-py==2.2.0 ; python_version >= "3.8" and python_version < "3.11"
+markdown-it-py[linkify]==2.2.0 ; python_version >= "3.8" and python_version < "3.11"
+markupsafe==2.1.2 ; python_version >= "3.8" and python_version < "3.11"
+matplotlib==3.7.1 ; python_version >= "3.8" and python_version < "3.11"
+mdit-py-plugins==0.3.3 ; python_version >= "3.8" and python_version < "3.11"
+mdurl==0.1.2 ; python_version >= "3.8" and python_version < "3.11"
+mpmath==1.3.0 ; python_version >= "3.8" and python_version < "3.11"
+msgpack==1.0.5 ; python_version >= "3.8" and python_version < "3.11"
+multidict==6.0.4 ; python_version >= "3.8" and python_version < "3.11"
+networkx==3.0 ; python_version >= "3.8" and python_version < "3.11"
+numpy==1.24.2 ; python_version < "3.11" and python_version >= "3.8"
+nvidia-cublas-cu11==11.10.3.66 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+nvidia-cuda-cupti-cu11==11.7.101 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+nvidia-cuda-nvrtc-cu11==11.7.99 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+nvidia-cuda-runtime-cu11==11.7.99 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+nvidia-cudnn-cu11==8.5.0.96 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+nvidia-cufft-cu11==10.9.0.58 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+nvidia-curand-cu11==10.2.10.91 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+nvidia-cusolver-cu11==11.4.0.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+nvidia-cusparse-cu11==11.7.4.91 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+nvidia-ml-py==11.495.46 ; python_version >= "3.8" and python_version < "3.11"
+nvidia-nccl-cu11==2.14.3 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+nvidia-nvtx-cu11==11.7.91 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+opencensus-context==0.1.3 ; python_version >= "3.8" and python_version < "3.11"
+opencensus==0.11.2 ; python_version >= "3.8" and python_version < "3.11"
+orjson==3.8.8 ; python_version >= "3.8" and python_version < "3.11"
+packaging==23.0 ; python_version < "3.11" and python_version >= "3.8"
+pandas==1.5.3 ; python_version >= "3.8" and python_version < "3.11"
+pettingzoo==1.22.4 ; python_version >= "3.8" and python_version < "3.11"
+pillow==9.4.0 ; python_version >= "3.8" and python_version < "3.11"
+pkgutil-resolve-name==1.3.10 ; python_version >= "3.8" and python_version < "3.9"
+platformdirs==3.2.0 ; python_version >= "3.8" and python_version < "3.11"
+prometheus-client==0.16.0 ; python_version >= "3.8" and python_version < "3.11"
+protobuf==3.17.0 ; python_version >= "3.8" and python_version < "3.11"
+psutil==5.9.4 ; python_version >= "3.8" and python_version < "3.11"
+py-spy==0.3.14 ; python_version >= "3.8" and python_version < "3.11"
+pyasn1-modules==0.2.8 ; python_version >= "3.8" and python_version < "3.11"
+pyasn1==0.4.8 ; python_version >= "3.8" and python_version < "3.11"
+pydantic==1.10.7 ; python_version >= "3.8" and python_version < "3.11"
+pydub==0.25.1 ; python_version >= "3.8" and python_version < "3.11"
+pygame==2.3.0 ; python_version >= "3.8" and python_version < "3.11"
+pygments==2.14.0 ; python_version >= "3.8" and python_version < "3.11"
+pyparsing==3.0.9 ; python_version >= "3.8" and python_version < "3.11"
+pyrsistent==0.19.3 ; python_version >= "3.8" and python_version < "3.11"
+python-dateutil==2.8.2 ; python_version >= "3.8" and python_version < "3.11"
+python-multipart==0.0.6 ; python_version >= "3.8" and python_version < "3.11"
+pytz==2023.3 ; python_version >= "3.8" and python_version < "3.11"
+pywavelets==1.4.1 ; python_version >= "3.8" and python_version < "3.11"
+pyyaml==6.0 ; python_version >= "3.8" and python_version < "3.11"
+ray[rllib,serve]==2.3.1 ; python_version >= "3.8" and python_version < "3.11"
+requests==2.28.2 ; python_version >= "3.8" and python_version < "3.11"
+rfc3986[idna2008]==1.5.0 ; python_version >= "3.8" and python_version < "3.11"
+rich==13.3.3 ; python_version >= "3.8" and python_version < "3.11"
+rsa==4.9 ; python_version >= "3.8" and python_version < "3.11"
+scikit-image==0.20.0 ; python_version >= "3.8" and python_version < "3.11"
+scipy==1.9.1 ; python_version < "3.11" and python_version >= "3.8"
+semantic-version==2.10.0 ; python_version >= "3.8" and python_version < "3.11"
+setuptools==67.6.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+six==1.16.0 ; python_version < "3.11" and python_version >= "3.8"
+smart-open==6.3.0 ; python_version >= "3.8" and python_version < "3.11"
+sniffio==1.3.0 ; python_version >= "3.8" and python_version < "3.11"
+starlette==0.26.1 ; python_version >= "3.8" and python_version < "3.11"
+sympy==1.11.1 ; python_version >= "3.8" and python_version < "3.11"
+tabulate==0.9.0 ; python_version >= "3.8" and python_version < "3.11"
+tensorboardx==2.6 ; python_version >= "3.8" and python_version < "3.11"
+tensorflow-probability==0.19.0 ; python_version >= "3.8" and python_version < "3.11"
+tifffile==2023.3.21 ; python_version >= "3.8" and python_version < "3.11"
+toolz==0.12.0 ; python_version >= "3.8" and python_version < "3.11"
+torch==2.0.0 ; python_version >= "3.8" and python_version < "3.11"
+tqdm==4.65.0 ; python_version >= "3.8" and python_version < "3.11"
+triton==2.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+typer==0.7.0 ; python_version >= "3.8" and python_version < "3.11"
+typing-extensions==4.5.0 ; python_version >= "3.8" and python_version < "3.11"
+uc-micro-py==1.0.1 ; python_version >= "3.8" and python_version < "3.11"
+urllib3==1.26.15 ; python_version >= "3.8" and python_version < "3.11"
+uvicorn==0.21.1 ; python_version >= "3.8" and python_version < "3.11"
+virtualenv==20.21.0 ; python_version >= "3.8" and python_version < "3.11"
+wcwidth==0.2.6 ; python_version >= "3.8" and python_version < "3.11"
+websockets==10.4 ; python_version >= "3.8" and python_version < "3.11"
+wheel==0.40.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.11"
+yarl==1.8.2 ; python_version >= "3.8" and python_version < "3.11"
+zipp==3.15.0 ; python_version >= "3.8" and python_version < "3.10"