# Last commit by ClementBM (a9fd1f2): fix bug and update model
import logging
import time

import gradio as gr
import numpy as np
import onnxruntime as ort
from pettingzoo.classic import connect_four_v3

from connectfour import ERROR_SCREEN
from models import MODEL_PATH
# Regenerate requirements.txt: poetry export -f requirements.txt --output requirements.txt --without-hashes
# Run locally: gradio connectfour/app.py
# Number of board columns (one "drop token" button per column).
column_count = 7
# Status text shown while a game is still in progress.
game_on_msg = "Game On"

# ONNX policy used to choose the bot's moves (no custom session options).
session = ort.InferenceSession(str(MODEL_PATH), None)

# Root Gradio container; the UI is assembled inside ``with demo:`` below.
demo = gr.Blocks()
def flatten_observation(obs):
    """Build the policy input row: the action mask followed by the flattened board.

    Args:
        obs: Observation dict with ``"action_mask"`` and ``"observation"``
            entries; the board must hold exactly ``2 * 6 * column_count`` values.

    Returns:
        A float32 array with a leading batch axis of size 1.
    """
    mask = np.array(obs["action_mask"])
    board = np.reshape(obs["observation"], 2 * 6 * column_count)
    features = np.concatenate([mask, board])
    batched = features[np.newaxis, ...]
    return batched.astype(np.float32)
def legal_moves(env, player_id):
    """Return the column indices whose action-mask entry is 1 for ``player_id``."""
    mask = env.observe(player_id)["action_mask"]
    columns = np.arange(column_count)
    return columns[mask == 1]
def done(env):
    """Tell whether any agent in ``env`` is terminated or truncated."""
    flags = [*env.terminations.values(), *env.truncations.values()]
    return np.any(flags)
def get_state_msg(env, human):
    """Return the status text for ``human``.

    While the game is running this is ``game_on_msg``. Once it is over, the
    text is "End of the game", with a WIN/LOSE suffix when ``human``'s reward
    is positive/negative and no suffix otherwise.
    """
    if not done(env):
        return game_on_msg
    reward = env.rewards[human]
    if reward > 0:
        suffix = ": You WIN !!"
    elif reward < 0:
        suffix = ": You LOSE !!"
    else:
        suffix = ""
    return "End of the game" + suffix
def play(env, human, action=None):
    """Advance the game by at most one move and report the resulting status.

    When it is not ``human``'s turn, the ONNX policy chooses the column. If the
    policy proposes an illegal column, a random legal one is used instead.
    When it is ``human``'s turn, ``action`` is played only if it is a legal
    column; a missing or illegal action leaves the board untouched.

    Args:
        env: PettingZoo Connect Four AEC environment (mutated in place).
        human: Agent id controlled by the user, e.g. ``"player_0"``.
        action: Column chosen by the user; ignored on the bot's turn.

    Returns:
        ``(env, message)``. ``message`` comes from ``get_state_msg``, or is
        ``"Restart the Game"`` if any step raised. The error is logged.
    """
    try:
        if done(env):
            # A finished game accepts no further moves.
            return env, get_state_msg(env, human)
        player = env.agent_selection
        legal = legal_moves(env, player)
        if player != human:
            logits = session.run(
                ["output"],
                {
                    "obs": flatten_observation(env.observe(player)),
                    "state_ins": [],
                },
            )
            action = int(np.argmax(logits[0]))
            if action not in legal:
                action = int(np.random.choice(legal))
        elif action is None or action not in legal:
            # Never substitute a random move for the user, and never step an
            # illegal column: per the PettingZoo docs, connect_four_v3 ends the
            # game as a loss for the mover on an illegal action.
            return env, get_state_msg(env, human)
        env.step(action)
        return env, get_state_msg(env, human)
    except Exception:
        logging.getLogger(__name__).exception("Failed to play a Connect Four move")
        return env, "Restart the Game"


def init_env(env, who_plays_first, human):
    """Reset ``env`` in place and, unless the user starts, let the bot open.

    Args:
        env: PettingZoo Connect Four AEC environment.
        who_plays_first: Radio value; anything other than ``"You"`` makes the
            bot play the first move.
        human: Agent id the user controlled so far. Used only as a fallback.
            The user's id for this game is derived from the reset env, because
            a stale id may coincide with the opener.

    Returns:
        The same ``env`` object.
    """
    env.reset()
    if who_plays_first != "You":
        # The bot owns whichever agent moves first after the reset, so the user
        # is the other agent. Passing a stale ``human`` equal to the opener
        # would make ``play`` treat the opening as the user's turn.
        opener = env.agent_selection
        user = next((agent for agent in env.agents if agent != opener), human)
        play(env, user)
    return env
def error_screen():
    """Load and return the array stored at ``ERROR_SCREEN``.

    It is used as the grid image shown before the first render.
    """
    with open(ERROR_SCREEN, "rb") as handle:
        return np.load(handle)
def create_env():
    """Build a fresh Connect Four env with RGB-array rendering and the user to move first."""
    fresh = connect_four_v3.env(render_mode="rgb_array")
    return init_env(fresh, who_plays_first="You", human="player_0")
with demo:
    # Per-session state: the user's agent id and the game environment.
    human = gr.State("player_0")
    env = gr.State(create_env())

    drop_token_btns = []

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("# Let's Play Connect Four !")

            who_plays_first = gr.Radio(
                label="Who plays first", choices=["You", "Bot"], value="You"
            )
            reinitialize = gr.Button("New Game")
            game_state = gr.Text(value=game_on_msg, interactive=False, label="Status")

        with gr.Column(scale=1):
            output = gr.Image(
                label="Connect Four Grid",
                type="numpy",
                show_label=False,
                value=error_screen(),
            )

            with gr.Row():
                for i in range(column_count):
                    with gr.Column(min_width=20):
                        # The column index travels as the button's elem_id
                        # (a string, as Gradio documents) and is parsed back
                        # in click_column.
                        drop_token_btns.append(gr.Button("X", elem_id=str(i)))

    def reinit_game(env, who_plays_first, human):
        """Start a new game.

        Returns new values for (env, human, game_state, who_plays_first).
        """
        env = init_env(env, who_plays_first, human)
        return [
            env,
            env.agent_selection,  # human
            get_state_msg(env, human),  # state_msg
            gr.update(interactive=True),  # who_plays_first
        ]

    def on_render_change(env):
        """Return the current board rendering for the image component."""
        return env.render()

    def wait(game_state_value):
        """Pause between the user's rendered move and the bot's reply.

        Locks the "who plays first" radio while the game is on.
        """
        if game_state_value == game_on_msg:
            time.sleep(0.7)
            return gr.update(interactive=False)
        return gr.update(interactive=True)

    def bot(env, game_state_value, human):
        """Let the bot move if the game is still on.

        Returns (status, radio update); the radio is unlocked once the game ends.
        """
        if game_state_value == game_on_msg:
            env, state_msg = play(env, human)
            return state_msg, gr.update(interactive=state_msg != game_on_msg)
        return game_state_value, gr.update(interactive=True)

    def click_column(env, human, evt: gr.EventData):
        """Play the clicked column for the user; returns (env, status)."""
        return play(env, human, int(evt.target.elem_id))

    def game_state_change(value):
        """Enable the column buttons only while the game is on."""
        return [gr.update(interactive=value == game_on_msg)] * column_count

    who_plays_first.change(
        reinit_game,
        [env, who_plays_first, human],
        outputs=[env, human, game_state, who_plays_first],
    ).then(on_render_change, inputs=[env], outputs=[output])

    reinitialize.click(
        reinit_game,
        [env, who_plays_first, human],
        outputs=[env, human, game_state, who_plays_first],
    ).then(on_render_change, inputs=[env], outputs=[output])

    # User move -> render -> short pause -> bot move -> render.
    for btn in drop_token_btns:
        btn.click(
            click_column,
            inputs=[env, human],
            outputs=[env, game_state],
        ).then(
            on_render_change, inputs=[env], outputs=[output]
        ).then(
            wait, inputs=[game_state], outputs=[who_plays_first]
        ).then(
            bot, inputs=[env, game_state, human], outputs=[game_state, who_plays_first]
        ).then(
            on_render_change, inputs=[env], outputs=[output]
        )

    game_state.change(
        game_state_change,
        game_state,
        outputs=drop_token_btns,
    )

demo.launch()