ClementBM committed on
Commit
9ca3472
1 Parent(s): f7cc099

make the game multi sessions

Browse files
Files changed (2) hide show
  1. connectfour/app.py +126 -37
  2. connectfour/connect4.py +0 -92
connectfour/app.py CHANGED
@@ -1,17 +1,94 @@
1
  import time
2
 
3
  import gradio as gr
 
 
 
4
 
5
- from connectfour.connect4 import Connect4
 
6
 
7
  # poetry export -f requirements.txt --output requirements.txt --without-hashes
8
  # gradio connectfour/app.py
9
 
10
-
11
  demo = gr.Blocks()
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  with demo:
14
- connect4 = Connect4("You")
 
 
15
  drop_token_btns = []
16
 
17
  with gr.Row():
@@ -22,70 +99,83 @@ with demo:
22
  label="Who plays first", choices=["You", "Bot"], value="You"
23
  )
24
  reinitialize = gr.Button("New Game")
25
-
26
- game_state = gr.Text(value="Game On", interactive=False, label="Status")
27
 
28
  with gr.Column(scale=1):
29
  output = gr.Image(
30
  label="Connect Four Grid",
31
  type="numpy",
32
  show_label=False,
33
- value=connect4.env.render(),
34
  )
35
 
36
  with gr.Row():
37
- for i in range(7):
38
  with gr.Column(min_width=20):
39
  drop_token_btns.append(gr.Button("X", elem_id=i))
40
 
41
- who_plays_first.change(
42
- connect4.init_env, who_plays_first, outputs=[output, game_state]
43
- )
44
-
45
- def reinit_game(who_plays_first):
46
- output, game_state = connect4.init_env(who_plays_first)
47
- return output, game_state, gr.Checkbox.update(interactive=True)
 
48
 
49
- reinitialize.click(
50
- reinit_game, who_plays_first, outputs=[output, game_state, who_plays_first]
51
- )
52
 
53
  def wait(game_state_value):
54
- if game_state_value == "Game On":
55
  time.sleep(0.7)
56
  return gr.Checkbox.update(interactive=False)
57
  else:
58
  return gr.Checkbox.update(interactive=True)
59
 
60
- def bot(game_state_value):
61
- if game_state_value == "Game On":
62
- rendered_env = connect4.play()
63
- return *rendered_env, gr.Checkbox.update(interactive=False) if rendered_env[
64
- 1
65
- ] == "Game On" else gr.Checkbox.update(interactive=True)
 
66
  return (
67
- gr.Image.update(),
68
  game_state_value,
69
  gr.Checkbox.update(interactive=True),
70
  )
71
 
72
- def click_column(evt: gr.EventData):
73
- output, game_state = connect4.play(int(evt.target.elem_id))
74
- return output, game_state
75
 
76
  def game_state_change(value):
77
- if value == "Game On":
78
- return [gr.Button.update(interactive=True)] * 7
79
- else:
80
- return [gr.Button.update(interactive=False)] * 7
 
 
 
81
 
82
- for i in range(7):
 
 
 
 
 
 
83
  drop_token_btns[i].click(
84
  click_column,
85
- outputs=[output, game_state],
 
 
 
 
 
86
  ).then(
87
- wait, inputs=[game_state], outputs=who_plays_first
88
- ).then(bot, inputs=[game_state], outputs=[output, game_state, who_plays_first])
89
 
90
  game_state.change(
91
  game_state_change,
@@ -93,5 +183,4 @@ with demo:
93
  outputs=drop_token_btns,
94
  )
95
 
96
-
97
  demo.launch()
 
1
  import time
2
 
3
  import gradio as gr
4
+ import numpy as np
5
+ import onnxruntime as ort
6
+ from pettingzoo.classic import connect_four_v3
7
 
8
+ from connectfour import ERROR_SCREEN
9
+ from models import MODEL_PATH
10
 
11
  # poetry export -f requirements.txt --output requirements.txt --without-hashes
12
  # gradio connectfour/app.py
13
 
14
# ONNX Runtime inference session for the bot's policy model. It is created once
# at import time and reused by every call to play().
session = ort.InferenceSession(str(MODEL_PATH), None)
demo = gr.Blocks()

# Number of columns on the board; also the number of drop buttons in the UI.
column_count = 7
# Status text shown while a game is still in progress.
game_on_msg = "Game On"
19
+
20
+
21
def flatten_observation(obs):
    """Pack an observation dict into one float32 row for the policy model.

    The action mask comes first, followed by the board observation reshaped
    to ``2 * 6 * column_count`` values. A leading axis of size 1 is added so
    the result can be fed to the model as a single-sample batch.

    Args:
        obs: Mapping with ``"action_mask"`` and ``"observation"`` entries.

    Returns:
        A ``numpy.float32`` array with a leading axis of length 1.

    Raises:
        ValueError: If ``obs["observation"]`` does not hold exactly
            ``2 * 6 * column_count`` values (raised by ``numpy.reshape``).
    """
    # Locals are named after their content so they no longer shadow this
    # function's own name.
    action_mask = np.array(obs["action_mask"])
    board_values = np.reshape(obs["observation"], 2 * 6 * column_count)
    features = np.concatenate([action_mask, board_values])
    return features[np.newaxis, ...].astype(np.float32)
26
+
27
+
28
def legal_moves(env, player_id):
    """Return the column indices that ``player_id`` is currently allowed to play.

    Args:
        env: Environment exposing ``observe(player_id)``, which returns a
            mapping with an ``"action_mask"`` entry.
        player_id: Identifier of the agent whose mask is inspected.

    Returns:
        A 1-D integer array with the indices whose mask entry equals 1
        (empty when no move is allowed).
    """
    # np.asarray makes the comparison element-wise even when the mask is a
    # plain list; with a list, `mask == 1` would collapse to a single False.
    action_mask = np.asarray(env.observe(player_id)["action_mask"])
    return np.flatnonzero(action_mask == 1)
30
+
31
+
32
def done(env):
    """Tell whether the game is over.

    Args:
        env: Environment exposing ``terminations`` and ``truncations``
            mappings of agent id to flag.

    Returns:
        ``True`` if any agent is flagged as terminated or truncated,
        otherwise ``False`` (also when both mappings are empty).
    """
    return bool(
        any(env.terminations.values()) or any(env.truncations.values())
    )
34
+
35
+
36
def get_state_msg(env, human):
    """Build the status text for the current state of the game.

    Args:
        env: Environment inspected through ``done(env)`` and ``env.rewards``.
        human: Agent id of the human player, used as a key into
            ``env.rewards``.

    Returns:
        ``game_on_msg`` while the game is running. Once it is over, returns
        "End of the game", suffixed with the win or lose text when the
        human's reward is positive or negative respectively.

    Raises:
        KeyError: If the game is over and ``human`` is not a key of
            ``env.rewards``.
    """
    if not done(env):
        return game_on_msg

    end_message = "End of the game"
    human_reward = env.rewards[human]
    if human_reward > 0:
        end_message += ": You WIN !!"
    elif human_reward < 0:
        end_message += ": You LOSE !!"
    # A reward of exactly 0 leaves the bare "End of the game" message.
    return end_message
46
+
47
+
48
# play(env, human, action=None): advance the game by one move and report the
# resulting status.
#
# When `human` differs from env.agent_selection (the agent whose turn it is),
# the move is chosen by the ONNX policy: session.run is fed the flattened
# observation of the current agent and the argmax of its first output becomes
# the action. An action that is not in legal_moves(...) is replaced by a random
# legal one. The action is then passed to env.step and the function returns
# (env, get_state_msg(env, human)).
#
# Any exception raised along the way is caught and discarded; the function then
# returns (env, "Restart the Game").
#
# NOTE(review): this diff view has lost the original indentation, so it is not
# visible here whether the legality check applies to every move or only to
# bot-chosen ones. Confirm against the real file before relying on either.
+ def play(env, human, action=None):
49
+ try:
50
+ if human != env.agent_selection:
51
+ action = session.run(
52
+ ["output"],
53
+ {
54
+ "obs": flatten_observation(env.observe(env.agent_selection)),
55
+ "state_ins": [],
56
+ },
57
+ )
58
+ action = int(np.argmax(action[0]))
59
+
60
+ if action not in legal_moves(env, env.agent_selection):
61
+ action = np.random.choice(legal_moves(env, env.agent_selection))
62
+
63
+ env.step(action)
64
+ return env, get_state_msg(env, human)
65
+ except Exception as e:
66
+ return env, f"Restart the Game"
67
+
68
+
69
def init_env(env, who_plays_first, human):
    """Reset ``env`` for a new game and make the opening move when needed.

    Args:
        env: Environment to reset; it is modified in place.
        who_plays_first: Value of the "Who plays first" choice. Any value
            other than "You" triggers one call to ``play`` right after the
            reset.
        human: Agent id forwarded to ``play`` for that opening move.

    Returns:
        The same ``env`` object. The status text returned by ``play`` is
        intentionally ignored here.
    """
    env.reset()

    human_opens = who_plays_first == "You"
    if not human_opens:
        play(env, human)

    return env
76
+
77
+
78
def error_screen():
    """Load and return the image stored at ``ERROR_SCREEN`` via ``numpy.load``."""
    # The local is named `image_file` so nothing inside the function shadows
    # the function's own name.
    with open(ERROR_SCREEN, "rb") as image_file:
        return np.load(image_file)
82
+
83
+
84
def create_env():
    """Create a Connect Four environment that renders RGB arrays, ready to play.

    Returns:
        The environment after ``init_env`` has reset it with the human
        ("player_0") playing first.
    """
    env = connect_four_v3.env(render_mode="rgb_array")
    return init_env(env, "You", "player_0")
86
+
87
+
88
  with demo:
89
+ human = gr.State("player_0")
90
+ env = gr.State(create_env())
91
+
92
  drop_token_btns = []
93
 
94
  with gr.Row():
 
99
  label="Who plays first", choices=["You", "Bot"], value="You"
100
  )
101
  reinitialize = gr.Button("New Game")
102
+ game_state = gr.Text(value=game_on_msg, interactive=False, label="Status")
 
103
 
104
  with gr.Column(scale=1):
105
  output = gr.Image(
106
  label="Connect Four Grid",
107
  type="numpy",
108
  show_label=False,
109
+ value=error_screen(),
110
  )
111
 
112
  with gr.Row():
113
+ for i in range(column_count):
114
  with gr.Column(min_width=20):
115
  drop_token_btns.append(gr.Button("X", elem_id=i))
116
 
117
def reinit_game(env, who_plays_first, human):
    """Start a new game and return the values for the components it updates.

    Args:
        env: Current environment; reset in place by ``init_env``.
        who_plays_first: Value of the "Who plays first" choice.
        human: Agent id of the human player from the previous game.

    Returns:
        A list of four items, in the order expected by the event outputs:
        the environment, the agent id of the human for the new game, the
        status text, and an update re-enabling the first-player choice.
    """
    env = init_env(env, who_plays_first, human)
    # After init_env (and the bot's opening move, if any) the agent whose turn
    # it is becomes the human player for this game.
    next_human = env.agent_selection
    return [
        env,
        next_human,  # human
        get_state_msg(env, next_human),  # state_msg
        gr.Checkbox.update(interactive=True),  # who_plays_first
    ]
125
 
126
def on_render_change(env):
    """Return the current rendering of ``env`` for display in the image output."""
    return env.render()
 
128
 
def wait(game_state_value):
    """Pause briefly before the bot's reply while a game is in progress.

    Args:
        game_state_value: Current status text.

    Returns:
        An update for the first-player choice: disabled (after a 0.7 s
        pause) while the status equals ``game_on_msg``, enabled otherwise.
    """
    if game_state_value != game_on_msg:
        return gr.Checkbox.update(interactive=True)

    time.sleep(0.7)
    return gr.Checkbox.update(interactive=False)
135
 
136
def bot(env, game_state_value):
    """Let the bot answer the human's move while the game is still running.

    Args:
        env: Environment for this session; ``play`` steps it in place.
        game_state_value: Status text produced by the preceding human move.

    Returns:
        A ``(status text, update)`` pair. The update keeps the first-player
        choice disabled while the game goes on and re-enables it once the
        game is over or was not running.
    """
    if game_state_value != game_on_msg:
        return (
            game_state_value,
            gr.Checkbox.update(interactive=True),
        )

    # `human` is not a parameter here, so the bare name would resolve to the
    # gr.State component of the enclosing scope rather than to an agent id.
    # Looking that object up in env.rewards fails as soon as the bot's move
    # ends the game, which turned a real result into "Restart the Game".
    # The bot only moves on its own turn, so the human is the other agent.
    human_id = next(
        (agent for agent in env.possible_agents if agent != env.agent_selection),
        None,
    )
    env, state_msg = play(env, human_id)
    game_over = state_msg != game_on_msg
    return state_msg, gr.Checkbox.update(interactive=game_over)
147
 
148
def click_column(env, human, evt: gr.EventData):
    """Play the human's move in the column whose button was clicked.

    Args:
        env: Environment for this session.
        human: Agent id of the human player.
        evt: Event data; ``evt.target.elem_id`` carries the column index
            assigned to the clicked button.

    Returns:
        The ``(env, status text)`` pair returned by ``play``.
    """
    column = int(evt.target.elem_id)
    env, state_msg = play(env, human, column)
    return env, state_msg
151
 
def game_state_change(value):
    """Enable the drop buttons only while the game is running.

    Args:
        value: Current status text.

    Returns:
        A list of ``column_count`` button updates, interactive exactly when
        ``value`` equals ``game_on_msg``.
    """
    game_running = value == game_on_msg
    return [gr.Button.update(interactive=game_running)] * column_count
154
+
155
+ who_plays_first.change(
156
+ reinit_game,
157
+ [env, who_plays_first, human],
158
+ outputs=[env, human, game_state, who_plays_first],
159
+ ).then(on_render_change, inputs=[env], outputs=[output])
160
 
161
+ reinitialize.click(
162
+ reinit_game,
163
+ [env, who_plays_first, human],
164
+ outputs=[env, human, game_state, who_plays_first],
165
+ ).then(on_render_change, inputs=[env], outputs=[output])
166
+
167
+ for i in range(column_count):
168
  drop_token_btns[i].click(
169
  click_column,
170
+ inputs=[env, human],
171
+ outputs=[env, game_state],
172
+ ).then(on_render_change, inputs=[env], outputs=[output]).then(
173
+ wait, inputs=[game_state], outputs=[who_plays_first]
174
+ ).then(
175
+ bot, inputs=[env, game_state], outputs=[game_state, who_plays_first]
176
  ).then(
177
+ on_render_change, inputs=[env], outputs=[output]
178
+ )
179
 
180
  game_state.change(
181
  game_state_change,
 
183
  outputs=drop_token_btns,
184
  )
185
 
 
186
  demo.launch()
connectfour/connect4.py DELETED
@@ -1,92 +0,0 @@
1
- import numpy as np
2
- import onnxruntime as ort
3
- from pettingzoo.classic import connect_four_v3
4
- from connectfour import ERROR_SCREEN
5
-
6
- from models import MODEL_PATH
7
-
8
-
9
- class Connect4:
10
- def __init__(self, who_plays_first) -> None:
11
- self.init_env(who_plays_first)
12
- self.session = ort.InferenceSession(str(MODEL_PATH), None)
13
-
14
- def init_env(self, who_plays_first):
15
- self.env = connect_four_v3.env(render_mode="rgb_array")
16
- self.env.reset()
17
-
18
- if who_plays_first == "You":
19
- self.human = self.current_player_id
20
- else:
21
- self.play()
22
- self.human = self.current_player_id
23
-
24
- return self.render_and_state
25
-
26
- def flatten_observation(self, obs):
27
- flatten_action_mask = np.array(obs["action_mask"])
28
- flatten_observation = np.reshape(obs["observation"], 2 * 6 * 7)
29
- flatten_obs = np.concatenate([flatten_action_mask, flatten_observation])
30
- return flatten_obs[np.newaxis, ...].astype(np.float32)
31
-
32
- def play(self, action=None):
33
- try:
34
- if self.human != self.current_player_id:
35
- action = self.session.run(
36
- ["output"],
37
- {
38
- "obs": self.flatten_observation(
39
- self.env.observe(self.current_player_id)
40
- ),
41
- "state_ins": [],
42
- },
43
- )
44
- action = int(np.argmax(action[0]))
45
-
46
- if action not in self.legal_moves:
47
- action = np.random.choice(self.legal_moves)
48
-
49
- self.env.step(action)
50
-
51
- return self.render_and_state
52
- except:
53
- return self.blue_screen()
54
-
55
- @property
56
- def current_player_id(self):
57
- return self.env.agent_selection
58
-
59
- @property
60
- def current_observation(self):
61
- return self.env.observe(self.current_player_id)
62
-
63
- @property
64
- def legal_moves(self):
65
- return np.arange(7)[self.current_observation["action_mask"] == 1]
66
-
67
- @property
68
- def done(self):
69
- return np.any(
70
- list(self.env.terminations.values()) + list(self.env.truncations.values())
71
- )
72
-
73
- @property
74
- def render_and_state(self):
75
- if self.done:
76
- if self.human not in self.env.rewards:
77
- return self.blue_screen()
78
-
79
- end_message = "End of the game"
80
- if self.env.rewards[self.human] > 0:
81
- end_message += ": You WIN !!"
82
- elif self.env.rewards[self.human] < 0:
83
- end_message += ": You LOSE !!"
84
- return self.env.render(), end_message
85
-
86
- return self.env.render(), "Game On"
87
-
88
- def blue_screen(self):
89
- with open(ERROR_SCREEN, "rb") as f:
90
- error_screen = np.load(f)
91
-
92
- return (error_screen, "Restart the Game")