"""Gradio demo: play Connect Four against a PPO agent restored from an RLlib checkpoint."""

import time

import gradio as gr
import numpy as np
import ray
import ray.rllib.algorithms.ppo as ppo
from pettingzoo.classic import connect_four_v3
from ray.tune import register_env

from connectfour.checkpoint import CHECKPOINT
from connectfour.training.models import Connect4MaskModel
from connectfour.training.wrappers import Connect4Env

# Policy inside the restored checkpoint that chooses the bot's moves.
POLICY_ID = "learned_v5"
# Number of board columns, i.e. number of "drop token" buttons in the UI.
N_COLUMNS = 7

# Regenerate requirements.txt with:
# poetry export -f requirements.txt --output requirements.txt --without-hashes


class Connect4:
    """One Connect Four game between a human and the RLlib bot.

    NOTE(review): the demo creates a single module-level instance, so every
    browser session drives the same board.
    """

    def __init__(self, who_plays_first) -> None:
        """Start Ray, register the "connect4" RLlib env, and start a game.

        Args:
            who_plays_first: "You" or "Bot". The bot needs ``self.algo``,
                which only exists after :meth:`get_algo`, so construct with
                "You" and call ``get_algo`` before letting the bot open.
        """
        ray.init(include_dashboard=False, ignore_reinit_error=True)

        def env_creator(config):
            # How to build the raw PettingZoo environment (config is unused).
            return connect_four_v3.env(render_mode="rgb_array")

        # Register the wrapped env under the RLlib name used by get_algo().
        register_env("connect4", lambda config: Connect4Env(env_creator(config)))

        self.init_env(who_plays_first)

    def init_env(self, who_plays_first):
        """Reset the board and, if the bot opens, play its first move.

        Args:
            who_plays_first: "You" lets the human open; any other value (the
                UI sends "Bot") makes the bot move first.

        Returns:
            ``(rendered_board, status_message)`` as built by ``render_and_state``.
        """
        self.env = Connect4Env(connect_four_v3.env(render_mode="rgb_array"))
        self.done = False
        self.obs, _ = self.env.reset()
        if who_plays_first == "You":
            self.human = self.player_id
        else:
            # Clear the previous game's seat before the bot moves. A stale
            # value equal to the current player would make play() treat this
            # as the human's turn. A missing attribute (fresh instance) would
            # raise AttributeError.
            self.human = None
            self.play()
            self.human = self.player_id
        return self.render_and_state

    def get_algo(self, checkpoint):
        """Build a PPO algorithm for the "connect4" env and restore ``checkpoint``.

        Exploration is switched off so the bot's actions are not sampled for
        exploration. The result is stored in ``self.algo``.
        """
        config = (
            ppo.PPOConfig()
            .environment("connect4")
            .framework("torch")
            .training(model={"custom_model": Connect4MaskModel})
        )
        config.explore = False
        self.algo = config.build()
        self.algo.restore(checkpoint)

    def play(self, action=None):
        """Advance the game by one move and return the new board and status.

        On the bot's turn, ``action`` is ignored. The restored policy picks
        the column, falling back to a random legal column if it proposes an
        illegal one.

        On the human's turn, ``action`` is the column index. A missing or
        illegal column (e.g. a full one) leaves the board unchanged instead
        of playing a random move on the human's behalf. Calls after the game
        has ended are no-ops.

        Returns:
            ``(rendered_board, status_message)`` as built by ``render_and_state``.
        """
        if self.done:
            return self.render_and_state

        legal_moves = self.legal_moves
        if self.human != self.player_id:
            action = self.algo.compute_single_action(
                self.obs[self.player_id], policy_id=POLICY_ID
            )
            if action not in legal_moves:
                action = np.random.choice(legal_moves)
        elif action is None or action not in legal_moves:
            # Human's turn without a playable column: keep the board as is.
            return self.render_and_state

        player_actions = {self.player_id: action}
        self.obs, self.reward, terminated, truncated, _ = self.env.step(player_actions)
        self.done = terminated["__all__"] or truncated["__all__"]
        return self.render_and_state

    @property
    def render_and_state(self):
        """Return ``(rgb_frame, status)``; status is "Game On" until the game ends."""
        if not self.done:
            return self.env.render(), "Game On"
        end_message = "End of the game"
        if self.reward[self.human] > 0:
            end_message += ": You WIN !!"
        elif self.reward[self.human] < 0:
            end_message += ": You LOSE !!"
        return self.env.render(), end_message

    @property
    def player_id(self):
        """Id of the first key of the latest observation dict.

        Presumably the wrapper only returns an observation for the agent
        whose turn it is — confirm against Connect4Env.
        """
        return list(self.obs.keys())[0]

    @property
    def legal_moves(self):
        """Column indices whose action-mask entry is 1 for the current player."""
        mask = np.asarray(self.obs[self.player_id]["action_mask"])
        return np.flatnonzero(mask == 1)


demo = gr.Blocks()

with demo:
    connect4 = Connect4("You")
    connect4.get_algo(str(CHECKPOINT))

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("# Let's Play Connect Four !")
            who_plays_first = gr.Radio(
                label="Who plays first", choices=["You", "Bot"], value="You"
            )
            reinitialize = gr.Button("New Game")
            game_state = gr.Text(value="Game On", interactive=False, label="Status")

        with gr.Column(scale=1):
            output = gr.Image(
                label="Connect Four Grid",
                type="numpy",
                show_label=False,
                value=connect4.env.render(),
            )
            # One narrow column per board column, under the grid image.
            with gr.Row():
                drop_token_btns = []
                for _ in range(N_COLUMNS):
                    with gr.Column(scale=1, min_width=20):
                        drop_token_btns.append(gr.Button("X"))

    who_plays_first.change(
        connect4.init_env, who_plays_first, outputs=[output, game_state]
    )

    def reinit_game(who_plays_first):
        """Start a new game and unlock the "who plays first" radio."""
        board, status = connect4.init_env(who_plays_first)
        return board, status, gr.Checkbox.update(interactive=True)

    reinitialize.click(
        reinit_game, who_plays_first, outputs=[output, game_state, who_plays_first]
    )

    def wait(game_state_value):
        """Pause 1 s between the human's move and the bot's reply while playing.

        Locks the "who plays first" radio while the game is on and unlocks
        it once the game has ended.
        """
        if game_state_value == "Game On":
            time.sleep(1)
            return gr.Checkbox.update(interactive=False)
        return gr.Checkbox.update(interactive=True)

    def bot(game_state_value):
        """Let the bot reply if the game is still on.

        Returns the board, the status, and the radio lock state.
        """
        if game_state_value != "Game On":
            return (
                gr.Image.update(),
                game_state_value,
                gr.Checkbox.update(interactive=True),
            )
        board, status = connect4.play()
        # Keep the radio locked while the game continues.
        return board, status, gr.Checkbox.update(interactive=status != "Game On")

    def make_human_move(column):
        """Return a zero-argument click handler dropping a token in ``column``.

        A factory is used so each handler captures its own ``column``
        (avoids the late-binding closure pitfall of a lambda in a loop).
        """
        return lambda: connect4.play(column)

    for column, button in enumerate(drop_token_btns):
        button.click(
            make_human_move(column),
            outputs=[output, game_state],
        ).then(
            wait, inputs=[game_state], outputs=who_plays_first
        ).then(
            bot, inputs=[game_state], outputs=[output, game_state, who_plays_first]
        )

    def game_state_change(value):
        """Enable the drop buttons while the game is on; disable them afterwards."""
        interactive = value == "Game On"
        return [gr.Button.update(interactive=interactive) for _ in drop_token_btns]

    game_state.change(
        game_state_change,
        game_state,
        outputs=drop_token_btns,
    )

demo.launch()