Spaces:
Runtime error
Runtime error
add error screen
Browse files
- connectfour/app.py +19 -1
- connectfour/training/train.py +11 -3
- error-screen.npy +3 -0
connectfour/app.py
CHANGED
@@ -14,6 +14,7 @@ from connectfour.training.wrappers import Connect4Env
|
|
14 |
POLICY_ID = "learned_v5"
|
15 |
|
16 |
# poetry export -f requirements.txt --output requirements.txt --without-hashes
|
|
|
17 |
|
18 |
|
19 |
class Connect4:
|
@@ -55,6 +56,9 @@ class Connect4:
|
|
55 |
self.algo.restore(checkpoint)
|
56 |
|
57 |
def play(self, action=None):
|
|
|
|
|
|
|
58 |
if self.human != self.player_id:
|
59 |
action = self.algo.compute_single_action(
|
60 |
self.obs[self.player_id], policy_id=POLICY_ID
|
@@ -73,8 +77,11 @@ class Connect4:
|
|
73 |
|
74 |
@property
|
75 |
def render_and_state(self):
|
76 |
-
end_message = "End of the game"
|
77 |
if self.done:
|
|
|
|
|
|
|
|
|
78 |
if self.reward[self.human] > 0:
|
79 |
end_message += ": You WIN !!"
|
80 |
elif self.reward[self.human] < 0:
|
@@ -83,6 +90,12 @@ class Connect4:
|
|
83 |
|
84 |
return self.env.render(), "Game On"
|
85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
@property
|
87 |
def player_id(self):
|
88 |
return list(self.obs.keys())[0]
|
@@ -91,6 +104,11 @@ class Connect4:
|
|
91 |
def legal_moves(self):
|
92 |
return np.arange(7)[self.obs[self.player_id]["action_mask"] == 1]
|
93 |
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
demo = gr.Blocks()
|
96 |
|
|
|
14 |
POLICY_ID = "learned_v5"
|
15 |
|
16 |
# poetry export -f requirements.txt --output requirements.txt --without-hashes
|
17 |
+
# gradio connectfour/app.py
|
18 |
|
19 |
|
20 |
class Connect4:
|
|
|
56 |
self.algo.restore(checkpoint)
|
57 |
|
58 |
def play(self, action=None):
|
59 |
+
if self.has_erroneous_state():
|
60 |
+
return self.blue_screen()
|
61 |
+
|
62 |
if self.human != self.player_id:
|
63 |
action = self.algo.compute_single_action(
|
64 |
self.obs[self.player_id], policy_id=POLICY_ID
|
|
|
77 |
|
78 |
@property
|
79 |
def render_and_state(self):
|
|
|
80 |
if self.done:
|
81 |
+
if hasattr(self, "reward") and self.human not in self.reward:
|
82 |
+
return self.blue_screen()
|
83 |
+
|
84 |
+
end_message = "End of the game"
|
85 |
if self.reward[self.human] > 0:
|
86 |
end_message += ": You WIN !!"
|
87 |
elif self.reward[self.human] < 0:
|
|
|
90 |
|
91 |
return self.env.render(), "Game On"
|
92 |
|
93 |
+
def blue_screen(self, path="error-screen.npy"):
    """Load the error-screen image and pair it with a restart prompt.

    Args:
        path: Location of the ``.npy`` file holding the error image.
            Defaults to ``"error-screen.npy"``, which is resolved against
            the current working directory (the value that was previously
            hard-coded).

    Returns:
        tuple: ``(image, "Restart the Game")`` where ``image`` is whatever
        array ``np.load`` reads from ``path``.

    Raises:
        OSError: If ``path`` cannot be opened.
    """
    # The file is opened explicitly so the handle is closed as soon as
    # the array has been read.
    with open(path, "rb") as f:
        error_screen = np.load(f)

    return (error_screen, "Restart the Game")
|
98 |
+
|
99 |
@property
|
100 |
def player_id(self):
|
101 |
return list(self.obs.keys())[0]
|
|
|
104 |
def legal_moves(self):
|
105 |
return np.arange(7)[self.obs[self.player_id]["action_mask"] == 1]
|
106 |
|
107 |
+
def has_erroneous_state(self):
    """Report whether the game currently has no observation to act on.

    Returns:
        bool: ``True`` when ``self.obs`` has no keys, ``False`` otherwise.
    """
    # The comparison already yields the boolean; no need for an
    # if/return True/return False ladder or an intermediate list.
    return len(self.obs.keys()) == 0
|
111 |
+
|
112 |
|
113 |
demo = gr.Blocks()
|
114 |
|
connectfour/training/train.py
CHANGED
@@ -27,7 +27,7 @@ def get_cli_args():
|
|
27 |
Create CLI parser and return parsed arguments
|
28 |
|
29 |
python connectfour/training/train.py --num-cpus 4 --num-gpus 1 --stop-iters 10 --win-rate-threshold 0.50
|
30 |
-
python connectfour/training/train.py --num-gpus 1 --stop-iters
|
31 |
python connectfour/training/train.py --num-cpus 5 --num-gpus 1 --stop-iters 200
|
32 |
"""
|
33 |
parser = argparse.ArgumentParser()
|
@@ -68,7 +68,10 @@ if __name__ == "__main__":
|
|
68 |
args = get_cli_args()
|
69 |
|
70 |
ray.init(
|
71 |
-
num_cpus=args.num_cpus or None,
|
|
|
|
|
|
|
72 |
)
|
73 |
|
74 |
# define how to make the environment
|
@@ -126,6 +129,8 @@ if __name__ == "__main__":
|
|
126 |
"win_rate": "win_rate",
|
127 |
"league_size": "league_size",
|
128 |
},
|
|
|
|
|
129 |
sort_by_metric=True,
|
130 |
),
|
131 |
checkpoint_config=air.CheckpointConfig(
|
@@ -135,6 +140,9 @@ if __name__ == "__main__":
|
|
135 |
),
|
136 |
).fit()
|
137 |
|
138 |
-
print(
|
|
|
|
|
|
|
139 |
|
140 |
ray.shutdown()
|
|
|
27 |
Create CLI parser and return parsed arguments
|
28 |
|
29 |
python connectfour/training/train.py --num-cpus 4 --num-gpus 1 --stop-iters 10 --win-rate-threshold 0.50
|
30 |
+
python connectfour/training/train.py --num-gpus 1 --stop-iters 1 --win-rate-threshold 0.50
|
31 |
python connectfour/training/train.py --num-cpus 5 --num-gpus 1 --stop-iters 200
|
32 |
"""
|
33 |
parser = argparse.ArgumentParser()
|
|
|
68 |
args = get_cli_args()
|
69 |
|
70 |
ray.init(
|
71 |
+
num_cpus=args.num_cpus or None,
|
72 |
+
num_gpus=args.num_gpus,
|
73 |
+
include_dashboard=False,
|
74 |
+
resources={"accelerator_type:RTX": 1},
|
75 |
)
|
76 |
|
77 |
# define how to make the environment
|
|
|
129 |
"win_rate": "win_rate",
|
130 |
"league_size": "league_size",
|
131 |
},
|
132 |
+
mode="max",
|
133 |
+
metric="win_rate",
|
134 |
sort_by_metric=True,
|
135 |
),
|
136 |
checkpoint_config=air.CheckpointConfig(
|
|
|
140 |
),
|
141 |
).fit()
|
142 |
|
143 |
+
print(
|
144 |
+
"Best checkpoint",
|
145 |
+
results.get_best_result(metric="win_rate", mode="max").checkpoint,
|
146 |
+
)
|
147 |
|
148 |
ray.shutdown()
|
error-screen.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6b72c5a148f41583927cd127d1d2b51073adec2ebd33ace7d4c074142d16d992
|
3 |
+
size 4316726
|