ClementBM commited on
Commit
7db569a
1 Parent(s): d1757d4

add error screen

Browse files
connectfour/app.py CHANGED
@@ -14,6 +14,7 @@ from connectfour.training.wrappers import Connect4Env
14
  POLICY_ID = "learned_v5"
15
 
16
  # poetry export -f requirements.txt --output requirements.txt --without-hashes
 
17
 
18
 
19
  class Connect4:
@@ -55,6 +56,9 @@ class Connect4:
55
  self.algo.restore(checkpoint)
56
 
57
  def play(self, action=None):
 
 
 
58
  if self.human != self.player_id:
59
  action = self.algo.compute_single_action(
60
  self.obs[self.player_id], policy_id=POLICY_ID
@@ -73,8 +77,11 @@ class Connect4:
73
 
74
  @property
75
  def render_and_state(self):
76
- end_message = "End of the game"
77
  if self.done:
 
 
 
 
78
  if self.reward[self.human] > 0:
79
  end_message += ": You WIN !!"
80
  elif self.reward[self.human] < 0:
@@ -83,6 +90,12 @@ class Connect4:
83
 
84
  return self.env.render(), "Game On"
85
 
 
 
 
 
 
 
86
  @property
87
  def player_id(self):
88
  return list(self.obs.keys())[0]
@@ -91,6 +104,11 @@ class Connect4:
91
  def legal_moves(self):
92
  return np.arange(7)[self.obs[self.player_id]["action_mask"] == 1]
93
 
 
 
 
 
 
94
 
95
  demo = gr.Blocks()
96
 
 
14
  POLICY_ID = "learned_v5"
15
 
16
  # poetry export -f requirements.txt --output requirements.txt --without-hashes
17
+ # gradio connectfour/app.py
18
 
19
 
20
  class Connect4:
 
56
  self.algo.restore(checkpoint)
57
 
58
  def play(self, action=None):
59
+ if self.has_erroneous_state():
60
+ return self.blue_screen()
61
+
62
  if self.human != self.player_id:
63
  action = self.algo.compute_single_action(
64
  self.obs[self.player_id], policy_id=POLICY_ID
 
77
 
78
  @property
79
  def render_and_state(self):
 
80
  if self.done:
81
+ if hasattr(self, "reward") and self.human not in self.reward:
82
+ return self.blue_screen()
83
+
84
+ end_message = "End of the game"
85
  if self.reward[self.human] > 0:
86
  end_message += ": You WIN !!"
87
  elif self.reward[self.human] < 0:
 
90
 
91
  return self.env.render(), "Game On"
92
 
93
+ def blue_screen(self):
94
+ with open("error-screen.npy", "rb") as f:
95
+ error_screen = np.load(f)
96
+
97
+ return (error_screen, "Restart the Game")
98
+
99
  @property
100
  def player_id(self):
101
  return list(self.obs.keys())[0]
 
104
  def legal_moves(self):
105
  return np.arange(7)[self.obs[self.player_id]["action_mask"] == 1]
106
 
107
+ def has_erroneous_state(self):
108
+ if len(list(self.obs.keys())) == 0:
109
+ return True
110
+ return False
111
+
112
 
113
  demo = gr.Blocks()
114
 
connectfour/training/train.py CHANGED
@@ -27,7 +27,7 @@ def get_cli_args():
27
  Create CLI parser and return parsed arguments
28
 
29
  python connectfour/training/train.py --num-cpus 4 --num-gpus 1 --stop-iters 10 --win-rate-threshold 0.50
30
- python connectfour/training/train.py --num-gpus 1 --stop-iters 10 --win-rate-threshold 0.50
31
  python connectfour/training/train.py --num-cpus 5 --num-gpus 1 --stop-iters 200
32
  """
33
  parser = argparse.ArgumentParser()
@@ -68,7 +68,10 @@ if __name__ == "__main__":
68
  args = get_cli_args()
69
 
70
  ray.init(
71
- num_cpus=args.num_cpus or None, num_gpus=args.num_gpus, include_dashboard=False
 
 
 
72
  )
73
 
74
  # define how to make the environment
@@ -126,6 +129,8 @@ if __name__ == "__main__":
126
  "win_rate": "win_rate",
127
  "league_size": "league_size",
128
  },
 
 
129
  sort_by_metric=True,
130
  ),
131
  checkpoint_config=air.CheckpointConfig(
@@ -135,6 +140,9 @@ if __name__ == "__main__":
135
  ),
136
  ).fit()
137
 
138
- print("Best checkpoint", results.get_best_result().checkpoint)
 
 
 
139
 
140
  ray.shutdown()
 
27
  Create CLI parser and return parsed arguments
28
 
29
  python connectfour/training/train.py --num-cpus 4 --num-gpus 1 --stop-iters 10 --win-rate-threshold 0.50
30
+ python connectfour/training/train.py --num-gpus 1 --stop-iters 1 --win-rate-threshold 0.50
31
  python connectfour/training/train.py --num-cpus 5 --num-gpus 1 --stop-iters 200
32
  """
33
  parser = argparse.ArgumentParser()
 
68
  args = get_cli_args()
69
 
70
  ray.init(
71
+ num_cpus=args.num_cpus or None,
72
+ num_gpus=args.num_gpus,
73
+ include_dashboard=False,
74
+ resources={"accelerator_type:RTX": 1},
75
  )
76
 
77
  # define how to make the environment
 
129
  "win_rate": "win_rate",
130
  "league_size": "league_size",
131
  },
132
+ mode="max",
133
+ metric="win_rate",
134
  sort_by_metric=True,
135
  ),
136
  checkpoint_config=air.CheckpointConfig(
 
140
  ),
141
  ).fit()
142
 
143
+ print(
144
+ "Best checkpoint",
145
+ results.get_best_result(metric="win_rate", mode="max").checkpoint,
146
+ )
147
 
148
  ray.shutdown()
error-screen.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b72c5a148f41583927cd127d1d2b51073adec2ebd33ace7d4c074142d16d992
3
+ size 4316726