Spaces:
Runtime error
Runtime error
File size: 5,491 Bytes
cd63926 ffe7549 9ca3472 ffe7549 9ca3472 ffe7549 d1757d4 7db569a d1757d4 9ca3472 d1757d4 9ca3472 ffe7549 9ca3472 ef91ec6 ffe7549 9ca3472 ffe7549 9ca3472 ffe7549 9ca3472 ef91ec6 ffe7549 9ca3472 ffe7549 9ca3472 ffe7549 9ca3472 ef91ec6 ffe7549 a9fd1f2 9ca3472 ffe7549 9ca3472 ffe7549 9ca3472 ef91ec6 9ca3472 ef91ec6 9ca3472 a9fd1f2 ef91ec6 9ca3472 ffe7549 ef91ec6 ffe7549 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import time
import gradio as gr
import numpy as np
import onnxruntime as ort
from pettingzoo.classic import connect_four_v3
from connectfour import ERROR_SCREEN
from models import MODEL_PATH
# Export requirements: poetry export -f requirements.txt --output requirements.txt --without-hashes
# Run locally: gradio connectfour/app.py
# ONNX Runtime session built from the model file at MODEL_PATH; play() runs it
# to choose the bot's move.
session = ort.InferenceSession(str(MODEL_PATH), None)
# Root Gradio Blocks container; the UI is declared inside the `with demo:` block.
demo = gr.Blocks()
# Number of board columns: sizes the legal-move index range and the row of
# drop buttons.
column_count = 7
# Status text used while a game is in progress; handlers compare against it to
# detect that state.
game_on_msg = "Game On"
def flatten_observation(obs):
    """Flatten an environment observation into a single-row float32 batch.

    Args:
        obs: Mapping with an ``"action_mask"`` entry and an ``"observation"``
            entry, both array-like.

    Returns:
        A ``float32`` array of shape ``(1, n)`` holding the action mask
        followed by the flattened observation.
    """
    action_mask = np.array(obs["action_mask"])
    # reshape(-1) flattens whatever board size is supplied instead of
    # hard-coding the number of planes and rows.
    board = np.reshape(obs["observation"], -1)
    features = np.concatenate([action_mask, board])
    # Add the leading batch axis and cast to float32.
    return features[np.newaxis, ...].astype(np.float32)
def legal_moves(env, player_id):
    """Return the column indices that ``player_id`` may currently play.

    The environment's observation for the player carries an ``action_mask``;
    the columns whose mask entry equals 1 are returned as an array.
    """
    columns = np.arange(column_count)
    mask = env.observe(player_id)["action_mask"]
    playable = mask == 1
    return columns[playable]
def done(env):
    """Tell whether any agent of ``env`` is flagged as terminated or truncated."""
    flags = [*env.terminations.values(), *env.truncations.values()]
    return np.any(flags)
def get_state_msg(env, human):
    """Build the status line for the current game, seen from the human's side.

    Returns ``game_on_msg`` while the game is running. Once it is over, returns
    "End of the game", suffixed with the outcome when the human's reward is
    non-zero.
    """
    if not done(env):
        return game_on_msg

    reward = env.rewards[human]
    if reward > 0:
        outcome = ": You WIN !!"
    elif reward < 0:
        outcome = ": You LOSE !!"
    else:
        # Neither positive nor negative: no winner is announced.
        outcome = ""
    return "End of the game" + outcome
def play(env, human, action=None):
    """Play one move for whichever agent is due to act.

    When the agent to act is not ``human``, the move is chosen by the ONNX
    model (argmax over its ``"output"``); otherwise ``action`` is used. A move
    that is not currently legal is replaced by a random legal one.

    Args:
        env: Game environment exposing ``agent_selection``, ``observe`` and
            ``step``.
        human: Identifier of the agent controlled by the human.
        action: Column chosen by the human; overwritten on the bot's turn.

    Returns:
        ``(env, message)`` where ``message`` is the status from
        ``get_state_msg``, or ``"Restart the Game"`` if the move failed.
    """
    import logging  # local import keeps this block self-contained

    try:
        current_agent = env.agent_selection
        if human != current_agent:
            outputs = session.run(
                ["output"],
                {
                    "obs": flatten_observation(env.observe(current_agent)),
                    "state_ins": [],
                },
            )
            action = int(np.argmax(outputs[0]))

        # Query the legal moves once and reuse them for the random fallback.
        allowed = legal_moves(env, current_agent)
        if action not in allowed:
            action = int(np.random.choice(allowed))

        env.step(action)
        return env, get_state_msg(env, human)
    except Exception:
        # UI boundary: keep the app responsive, but record the failure
        # instead of discarding it silently.
        logging.getLogger(__name__).exception("Could not play a move")
        return env, "Restart the Game"
def init_env(env, who_plays_first, human):
    """Reset ``env`` and, unless the human opens, let the bot play first.

    Args:
        env: Game environment to reset.
        who_plays_first: ``"You"`` when the human opens; any other value lets
            the bot make the opening move.
        human: Identifier the human used so far. It may be stale after the
            reset, so it is only a fallback when the bot opens.

    Returns:
        The same ``env`` object, reset and possibly one move in.
    """
    env.reset()
    if who_plays_first != "You":
        # The bot takes the opening seat, so the human is the other agent.
        # A stale ``human`` could equal the opening agent, in which case
        # play() would treat the turn as the human's and skip the model.
        opener = env.agent_selection
        opponent = next(
            (
                agent
                for agent in getattr(env, "possible_agents", ())
                if agent != opener
            ),
            human,
        )
        play(env, opponent)
    return env
def error_screen():
    """Load and return the array stored in the ``ERROR_SCREEN`` file."""
    with open(ERROR_SCREEN, "rb") as stream:
        return np.load(stream)
def create_env():
    """Create a Connect Four environment, reset so that the human opens.

    The environment is created with ``render_mode="rgb_array"`` and the human
    is ``"player_0"``.
    """
    fresh_env = connect_four_v3.env(render_mode="rgb_array")
    return init_env(fresh_env, "You", "player_0")
# Declare the Gradio UI and wire its events.
# NOTE(review): component nesting below is reconstructed from the statement
# order — confirm the intended layout.
with demo:
    # State holders for the human's agent id and the game environment.
    human = gr.State("player_0")
    env = gr.State(create_env())

    drop_token_btns = []

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("# Let's Play Connect Four !")
            who_plays_first = gr.Radio(
                label="Who plays first", choices=["You", "Bot"], value="You"
            )
            reinitialize = gr.Button("New Game")
            game_state = gr.Text(value=game_on_msg, interactive=False, label="Status")

        with gr.Column(scale=1):
            output = gr.Image(
                label="Connect Four Grid",
                type="numpy",
                show_label=False,
                value=error_screen(),
            )

            with gr.Row():
                for i in range(column_count):
                    with gr.Column(min_width=20):
                        # The element id carries the column index; it is read
                        # back in click_column(). Ids are strings.
                        drop_token_btns.append(gr.Button("X", elem_id=str(i)))

    def reinit_game(env, who_plays_first, human):
        """Start a new game; refresh the human id, the status and the radio."""
        env = init_env(env, who_plays_first, human)
        return [
            env,
            env.agent_selection,  # human: whoever is due to move now
            get_state_msg(env, human),  # state_msg
            gr.update(interactive=True),  # who_plays_first
        ]

    def on_render_change(env):
        """Render the environment for the image component."""
        return env.render()

    def wait(game_state_value):
        """Pause before the bot replies, locking the radio while the game runs."""
        if game_state_value == game_on_msg:
            time.sleep(0.7)
            return gr.update(interactive=False)
        return gr.update(interactive=True)

    def bot(env, game_state_value, human):
        """Let the bot move if the game is on; return the status and radio update."""
        if game_state_value != game_on_msg:
            return game_state_value, gr.update(interactive=True)

        env, state_msg = play(env, human)
        # The radio is unlocked only once the game is no longer running.
        return state_msg, gr.update(interactive=state_msg != game_on_msg)

    def click_column(env, human, evt: gr.EventData):
        """Play the human's move in the column encoded in the button's id."""
        env, state_msg = play(env, human, int(evt.target.elem_id))
        return env, state_msg

    def game_state_change(value):
        """Enable the drop buttons only while the game is running."""
        active = value == game_on_msg
        # One fresh update per button so no update object is shared between
        # outputs.
        return [gr.update(interactive=active) for _ in range(column_count)]

    who_plays_first.change(
        reinit_game,
        [env, who_plays_first, human],
        outputs=[env, human, game_state, who_plays_first],
    ).then(on_render_change, inputs=[env], outputs=[output])

    reinitialize.click(
        reinit_game,
        [env, who_plays_first, human],
        outputs=[env, human, game_state, who_plays_first],
    ).then(on_render_change, inputs=[env], outputs=[output])

    # Each drop button: human move -> render -> short pause -> bot move -> render.
    for button in drop_token_btns:
        button.click(
            click_column,
            inputs=[env, human],
            outputs=[env, game_state],
        ).then(on_render_change, inputs=[env], outputs=[output]).then(
            wait, inputs=[game_state], outputs=[who_plays_first]
        ).then(
            bot, inputs=[env, game_state, human], outputs=[game_state, who_plays_first]
        ).then(
            on_render_change, inputs=[env], outputs=[output]
        )

    game_state.change(
        game_state_change,
        game_state,
        outputs=drop_token_btns,
    )
# Launch the Gradio app.
demo.launch()
|