# [Scraped page header, not program code] Spaces: Runtime error
import logging
import time

import gradio as gr
import numpy as np
import onnxruntime as ort
from pettingzoo.classic import connect_four_v3

from connectfour import ERROR_SCREEN
from models import MODEL_PATH
# Export the dependency list:
#   poetry export -f requirements.txt --output requirements.txt --without-hashes
# Run the app locally:
#   gradio connectfour/app.py
# Module-level objects shared by the functions and UI wiring below.

# ONNX Runtime session loaded from MODEL_PATH, created without custom options.
session = ort.InferenceSession(str(MODEL_PATH), sess_options=None)

# Root Gradio container; its content is declared in the `with demo:` section.
demo = gr.Blocks()

# Number of board columns, i.e. the number of "drop token" buttons.
column_count = 7

# Status text shown while a game is running; callbacks compare against it to
# tell "still playing" apart from any end-of-game or error message.
game_on_msg = "Game On"
def flatten_observation(obs):
    """Pack an observation dict into a single-row batch for the ONNX model.

    Args:
        obs: Mapping with an ``"action_mask"`` entry and an ``"observation"``
            entry (the board array), as returned by ``env.observe``.

    Returns:
        A ``float32`` array of shape ``(1, n)``: the action mask followed by
        the flattened board, with a leading batch axis.
    """
    action_mask = np.asarray(obs["action_mask"])
    # reshape(-1) flattens the board whatever its dimensions are, so the
    # board size is no longer hard-coded here.
    board = np.asarray(obs["observation"]).reshape(-1)
    features = np.concatenate([action_mask, board])
    return features[np.newaxis, ...].astype(np.float32)
def legal_moves(env, player_id):
    """Return the column indices whose action-mask entry equals 1.

    Args:
        env: Environment exposing ``observe(player_id)``.
        player_id: Agent whose observation (and action mask) is inspected.

    Returns:
        A NumPy array holding the selected column indices.
    """
    mask = env.observe(player_id)["action_mask"]
    columns = np.arange(column_count)
    return columns[mask == 1]
def done(env):
    """Tell whether any agent is flagged as terminated or truncated.

    Args:
        env: Environment exposing ``terminations`` and ``truncations`` dicts.

    Returns:
        Truthy when at least one flag in either dict is set.
    """
    flags = [*env.terminations.values(), *env.truncations.values()]
    return np.any(flags)
def get_state_msg(env, human):
    """Build the status line displayed to the player.

    Args:
        env: Game environment; its ``rewards`` dict is read once the game ends.
        human: Agent id controlled by the human player.

    Returns:
        ``game_on_msg`` while the game is running; otherwise an end-of-game
        message, suffixed with the outcome when the human's reward is non-zero.
    """
    if not done(env):
        return game_on_msg

    reward = env.rewards[human]
    if reward > 0:
        outcome = ": You WIN !!"
    elif reward < 0:
        outcome = ": You LOSE !!"
    else:
        # A zero reward adds no outcome text.
        outcome = ""
    return "End of the game" + outcome
def play(env, human, action=None):
    """Advance the game by one move for the agent whose turn it is.

    When the agent to act is not the human, the move is chosen by the ONNX
    model and any ``action`` passed in is ignored. Whatever its origin, a
    missing or illegal move is replaced by a random legal one.

    Args:
        env: Game environment (stepped in place).
        human: Agent id controlled by the human player.
        action: Column chosen by the human; may be omitted.

    Returns:
        ``(env, status_message)``. If anything goes wrong the error is logged
        and the status message asks the player to restart.
    """
    try:
        current_agent = env.agent_selection

        if human != current_agent:
            # Bot's turn: take the highest-scoring output of the model.
            outputs = session.run(
                ["output"],
                {
                    "obs": flatten_observation(env.observe(current_agent)),
                    "state_ins": [],
                },
            )
            action = int(np.argmax(outputs[0]))

        # NOTE(review): the source's indentation was lost; this check is kept
        # outside the bot branch because init_env() calls play() without an
        # action and relies on this fallback — confirm against the original.
        allowed = legal_moves(env, current_agent)
        if action is None or action not in allowed:
            # Convert to a plain int so env.step never receives a NumPy scalar.
            action = int(np.random.choice(allowed))

        env.step(action)
        return env, get_state_msg(env, human)
    except Exception:
        # UI boundary: never let a callback crash, but keep the traceback in
        # the logs instead of discarding it.
        logging.getLogger(__name__).exception("Move failed; asking the player to restart")
        return env, "Restart the Game"
def init_env(env, who_plays_first, human):
    """Reset the environment and play one opening move unless the human starts.

    Args:
        env: Game environment (reset in place).
        who_plays_first: ``"You"`` when the human opens; any other value
            triggers one call to ``play`` before returning.
        human: Agent id controlled by the human player.

    Returns:
        The same ``env`` object.
    """
    env.reset()
    human_opens = who_plays_first == "You"
    if not human_opens:
        play(env, human)
    return env
def error_screen():
    """Load and return the array stored in the file at ``ERROR_SCREEN``."""
    with open(ERROR_SCREEN, "rb") as stream:
        return np.load(stream)
def create_env():
    """Create a Connect Four environment, reset with the human (``player_0``) to open."""
    fresh_env = connect_four_v3.env(render_mode="rgb_array")
    return init_env(fresh_env, "You", "player_0")
with demo:
    # Session state: the agent id the human controls, and the game itself.
    human = gr.State("player_0")
    env = gr.State(create_env())

    drop_token_btns = []

    # NOTE(review): the nesting of the layout below was reconstructed from a
    # source whose indentation was lost (in particular, whether the button row
    # sits inside the image column) — confirm against the deployed layout.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("# Let's Play Connect Four !")
            who_plays_first = gr.Radio(
                label="Who plays first", choices=["You", "Bot"], value="You"
            )
            reinitialize = gr.Button("New Game")
            game_state = gr.Text(value=game_on_msg, interactive=False, label="Status")

        with gr.Column(scale=1):
            output = gr.Image(
                label="Connect Four Grid",
                type="numpy",
                show_label=False,
                value=error_screen(),
            )
            with gr.Row():
                for i in range(column_count):
                    with gr.Column(min_width=20):
                        # elem_id carries the column index as a string;
                        # click_column parses it back to an int.
                        drop_token_btns.append(gr.Button("X", elem_id=str(i)))

    def reinit_game(env, who_plays_first, human):
        """Start a new game and return the refreshed UI state.

        Returns ``[env, human, status message, radio update]``, matching the
        outputs this callback is wired to.
        """
        # Decide the human's seat before the opening move. Passing the seat of
        # the previous game would make play() mistake the bot's first turn for
        # the human's and fall back to a random move. PettingZoo's Connect
        # Four names its agents "player_0" (moves first) and "player_1".
        human = "player_0" if who_plays_first == "You" else "player_1"
        env = init_env(env, who_plays_first, human)
        return [
            env,
            env.agent_selection,  # human: the agent due to move next
            get_state_msg(env, human),  # state_msg
            gr.Radio.update(interactive=True),  # who_plays_first
        ]

    def on_render_change(env):
        """Return the environment's rendering for the grid image."""
        return env.render()

    def wait(game_state_value):
        """Lock the radio while a game is running, after a short pause.

        The pause presumably lets the human's move be displayed before the
        bot answers. When the game is not running, the radio is unlocked
        immediately.
        """
        if game_state_value != game_on_msg:
            return gr.Radio.update(interactive=True)
        time.sleep(0.7)
        return gr.Radio.update(interactive=False)

    def bot(env, game_state_value, human):
        """Let the bot answer, if the game is still running.

        Returns ``(status message, radio update)``; the radio is unlocked as
        soon as the status is anything other than ``game_on_msg``.
        """
        if game_state_value != game_on_msg:
            return game_state_value, gr.Radio.update(interactive=True)

        env, state_msg = play(env, human)
        game_over = state_msg != game_on_msg
        return state_msg, gr.Radio.update(interactive=game_over)

    def click_column(env, human, evt: gr.EventData):
        """Play the human's move in the column of the clicked button."""
        return play(env, human, int(evt.target.elem_id))

    def game_state_change(value):
        """Enable the column buttons only while the game is running."""
        return [gr.Button.update(interactive=value == game_on_msg)] * column_count

    # Changing who starts and pressing "New Game" both restart the game and
    # then redraw the board.
    for restart_trigger in (who_plays_first.change, reinitialize.click):
        restart_trigger(
            reinit_game,
            [env, who_plays_first, human],
            outputs=[env, human, game_state, who_plays_first],
        ).then(on_render_change, inputs=[env], outputs=[output])

    # One turn: human move -> redraw -> pause -> bot move -> redraw.
    for button in drop_token_btns:
        button.click(
            click_column,
            inputs=[env, human],
            outputs=[env, game_state],
        ).then(on_render_change, inputs=[env], outputs=[output]).then(
            wait, inputs=[game_state], outputs=[who_plays_first]
        ).then(
            bot, inputs=[env, game_state, human], outputs=[game_state, who_plays_first]
        ).then(
            on_render_change, inputs=[env], outputs=[output]
        )

    game_state.change(
        game_state_change,
        game_state,
        outputs=drop_token_btns,
    )

demo.launch()