"""Gradio app to play Connect Four against a policy served by ONNX Runtime.

The game logic comes from PettingZoo's ``connect_four_v3`` environment; the
bot picks its moves by running the ONNX model stored at ``MODEL_PATH``.
"""

import logging
import time

import gradio as gr
import numpy as np
import onnxruntime as ort
from pettingzoo.classic import connect_four_v3

from connectfour import ERROR_SCREEN
from models import MODEL_PATH

# Handy commands:
# poetry export -f requirements.txt --output requirements.txt --without-hashes
# gradio connectfour/app.py

logger = logging.getLogger(__name__)

session = ort.InferenceSession(str(MODEL_PATH), None)
demo = gr.Blocks()

column_count = 7
game_on_msg = "Game On"


def flatten_observation(obs):
    """Turn an environment observation into the model's input row.

    The action mask and the flattened board observation are concatenated and
    returned as a float32 array with a leading batch axis of size one.
    """
    action_mask = np.array(obs["action_mask"])
    # 2 * 6 * column_count is the expected element count of the board
    # observation (presumably 2 player planes x 6 rows x 7 columns -- the
    # reshape fails loudly if the environment ever returns another size).
    board = np.reshape(obs["observation"], 2 * 6 * column_count)
    flat = np.concatenate([action_mask, board])
    return flat[np.newaxis, ...].astype(np.float32)


def legal_moves(env, player_id):
    """Return the column indices whose action-mask entry equals 1."""
    mask = np.asarray(env.observe(player_id)["action_mask"])
    return np.arange(column_count)[mask == 1]


def done(env):
    """Return True when any agent is terminated or truncated."""
    flags = list(env.terminations.values()) + list(env.truncations.values())
    return np.any(flags)


def get_state_msg(env, human):
    """Build the status text shown to the player.

    Returns ``game_on_msg`` while the game is running; otherwise an
    end-of-game message, suffixed with win/lose depending on the sign of the
    human's reward (no suffix when the reward is zero).
    """
    if not done(env):
        return game_on_msg

    end_message = "End of the game"
    reward = env.rewards[human]
    if reward > 0:
        end_message += ": You WIN !!"
    elif reward < 0:
        end_message += ": You LOSE !!"
    return end_message


def play(env, human, action=None):
    """Play one move for the agent whose turn it is.

    When it is not the human's turn, ``action`` is ignored and the move is
    the argmax of the model output. Any move that is not legal (including a
    missing ``action``) is replaced by a random legal one.

    Returns ``(env, status_message)``. Errors never propagate to the UI: they
    are logged and reported as the "Restart the Game" status.
    """
    try:
        current_agent = env.agent_selection

        if human != current_agent:
            outputs = session.run(
                ["output"],
                {
                    "obs": flatten_observation(env.observe(current_agent)),
                    "state_ins": [],
                },
            )
            action = int(np.argmax(outputs[0]))

        allowed = legal_moves(env, current_agent)
        if action is None or action not in allowed:
            action = int(np.random.choice(allowed))

        env.step(action)
        return env, get_state_msg(env, human)
    except Exception:
        # UI boundary: keep the app alive, but leave a trace for debugging
        # instead of discarding the error silently.
        logger.exception("Failed to play a move")
        return env, "Restart the Game"


def init_env(env, who_plays_first, human):
    """Reset ``env`` and, unless the human starts, let the bot open.

    ``who_plays_first`` is the radio value ("You" or "Bot").
    """
    env.reset()

    if who_plays_first != "You":
        opening_agent = env.agent_selection
        if human == opening_agent:
            # The bot takes the opening seat, so the human must sit on the
            # other one. Without this, play() would treat the opening turn
            # as the human's and fall back to a random move instead of
            # querying the model.
            human = next(
                (agent for agent in env.possible_agents if agent != opening_agent),
                None,
            )
        play(env, human)

    return env


def error_screen():
    """Load the image array stored in the ``ERROR_SCREEN`` file."""
    with open(ERROR_SCREEN, "rb") as f:
        screen = np.load(f)
    return screen


def create_env():
    """Create a fresh environment where the human ("player_0") moves first."""
    return init_env(connect_four_v3.env(render_mode="rgb_array"), "You", "player_0")


with demo:
    # Per-session state: the seat played by the human and the game itself.
    human = gr.State("player_0")
    env = gr.State(create_env())

    drop_token_btns = []

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("# Let's Play Connect Four !")

            who_plays_first = gr.Radio(
                label="Who plays first", choices=["You", "Bot"], value="You"
            )
            reinitialize = gr.Button("New Game")
            game_state = gr.Text(value=game_on_msg, interactive=False, label="Status")

        with gr.Column(scale=1):
            output = gr.Image(
                label="Connect Four Grid",
                type="numpy",
                show_label=False,
                value=error_screen(),
            )

            # One button per board column; elem_id carries the column index
            # so click_column can recover it from the event target.
            with gr.Row():
                for i in range(column_count):
                    with gr.Column(min_width=20):
                        drop_token_btns.append(gr.Button("X", elem_id=str(i)))

    def reinit_game(env, who_plays_first, human):
        """Start a new game; the agent to move afterwards is the human."""
        env = init_env(env, who_plays_first, human)
        human = env.agent_selection
        return [
            env,
            human,
            get_state_msg(env, human),  # state_msg
            gr.update(interactive=True),  # who_plays_first
        ]

    def on_render_change(env):
        """Return the environment's rendered frame for the image component."""
        return env.render()

    def wait(game_state_value):
        """Pause briefly before the bot answers and lock the radio meanwhile."""
        game_running = game_state_value == game_on_msg
        if game_running:
            time.sleep(0.7)
        return gr.update(interactive=not game_running)

    def bot(env, game_state_value, human):
        """Let the bot answer the human's move if the game is still on.

        Only the status and the radio update are returned; this relies on
        play() stepping ``env`` in place.
        """
        if game_state_value != game_on_msg:
            return game_state_value, gr.update(interactive=True)

        env, state_msg = play(env, human)
        # The radio is unlocked again only once the game is over.
        return state_msg, gr.update(interactive=state_msg != game_on_msg)

    def click_column(env, human, evt: gr.EventData):
        """Play the column whose button triggered the event."""
        env, state_msg = play(env, human, int(evt.target.elem_id))
        return env, state_msg

    def game_state_change(value):
        """Enable the column buttons only while the game is running."""
        game_running = value == game_on_msg
        return [gr.update(interactive=game_running) for _ in range(column_count)]

    who_plays_first.change(
        reinit_game,
        [env, who_plays_first, human],
        outputs=[env, human, game_state, who_plays_first],
    ).then(on_render_change, inputs=[env], outputs=[output])

    reinitialize.click(
        reinit_game,
        [env, who_plays_first, human],
        outputs=[env, human, game_state, who_plays_first],
    ).then(on_render_change, inputs=[env], outputs=[output])

    # Human move -> redraw -> short pause -> bot move -> redraw.
    for drop_token_btn in drop_token_btns:
        drop_token_btn.click(
            click_column,
            inputs=[env, human],
            outputs=[env, game_state],
        ).then(on_render_change, inputs=[env], outputs=[output]).then(
            wait, inputs=[game_state], outputs=[who_plays_first]
        ).then(
            bot, inputs=[env, game_state, human], outputs=[game_state, who_plays_first]
        ).then(
            on_render_change, inputs=[env], outputs=[output]
        )

    game_state.change(
        game_state_change,
        game_state,
        outputs=drop_token_btns,
    )

# Launched at import time on purpose: module-level behaviour is kept as is.
demo.launch()