Spaces:
Runtime error
Runtime error
import time | |
import gradio as gr | |
import numpy as np | |
import ray | |
import ray.rllib.algorithms.ppo as ppo | |
from pettingzoo.classic import connect_four_v3 | |
from ray.tune import register_env | |
from connectfour.checkpoint import CHECKPOINT | |
from connectfour.training.models import Connect4MaskModel | |
from connectfour.training.wrappers import Connect4Env | |
# Policy identifier handed to ``compute_single_action`` when the bot picks a
# move; presumably the name of the trained policy stored in the checkpoint.
POLICY_ID = "learned_v5"

# Command used to produce requirements.txt from the Poetry project:
# poetry export -f requirements.txt --output requirements.txt --without-hashes
class Connect4:
    """Connect Four game state shared between a human player and a bot.

    The instance holds the wrapped environment (``env``), the latest
    observation dict (``obs``), the latest reward mapping (``reward``), the
    ``done`` flag and the agent id played by the human (``human``).  The bot
    picks its moves with the algorithm restored by :meth:`get_algo`.
    """

    def __init__(self, who_plays_first) -> None:
        """Start Ray, register the ``"connect4"`` environment and begin a game.

        Args:
            who_plays_first: ``"You"`` when the human opens the game.  Any
                other value lets the bot open, which needs an algorithm to be
                loaded first (see :meth:`get_algo`).
        """
        ray.init(include_dashboard=False, ignore_reinit_error=True)

        # define how to make the environment
        def env_creator(config):
            return connect_four_v3.env(render_mode="rgb_array")

        # register that way to make the environment under an rllib name
        register_env("connect4", lambda config: Connect4Env(env_creator(config)))

        self.init_env(who_plays_first)

    def init_env(self, who_plays_first):
        """Create a fresh environment, reset it and decide who opens the game.

        Args:
            who_plays_first: ``"You"`` for a human opening; any other value
                makes the bot play the first move.

        Returns:
            The ``(rendered frame, status message)`` pair of the new game.

        Raises:
            AttributeError: if the bot has to open but :meth:`get_algo` has
                not been called yet.
        """
        orig_env = connect_four_v3.env(render_mode="rgb_array")
        self.env = Connect4Env(orig_env)
        self.done = False
        self.obs, _ = self.env.reset()

        if who_plays_first == "You":
            self.human = self.player_id
        else:
            # Forget the human id of any previous game first: otherwise a
            # stale id equal to the opening agent would make play() skip the
            # policy and fall back to a random move.
            self.human = None
            self.play()
            # After the bot's move the observation belongs to the other agent,
            # which is the one the human controls.
            self.human = self.player_id
        return self.render_and_state

    def get_algo(self, checkpoint):
        """Build a PPO algorithm for the ``"connect4"`` env and restore it.

        Args:
            checkpoint: checkpoint location passed to ``Algorithm.restore``.
        """
        config = (
            ppo.PPOConfig()
            .environment("connect4")
            .framework("torch")
            .training(model={"custom_model": Connect4MaskModel})
        )
        # The bot should not explore while playing against a human.
        config.explore = False
        self.algo = config.build()
        self.algo.restore(checkpoint)

    def play(self, action=None):
        """Play one move for whichever agent is to move.

        When the agent to move is not the human, ``action`` is ignored and the
        restored policy chooses the move.  A missing or illegal action is
        replaced by a random legal one.

        Args:
            action: column chosen by the human, or ``None``.

        Returns:
            The ``(rendered frame, status message)`` pair after the move.  If
            the game is already over nothing is played and the final state is
            returned unchanged.
        """
        if self.done:
            # Never step an environment whose game has already finished.
            return self.render_and_state

        current_player = self.player_id
        if self.human != current_player:
            action = self.algo.compute_single_action(
                self.obs[current_player], policy_id=POLICY_ID
            )

        legal_moves = self.legal_moves
        if action not in legal_moves:
            action = np.random.choice(legal_moves)

        player_actions = {current_player: action}
        self.obs, self.reward, terminated, truncated, _ = self.env.step(
            player_actions
        )
        self.done = terminated["__all__"] or truncated["__all__"]
        return self.render_and_state

    @property
    def render_and_state(self):
        """Rendered frame plus a status message for the UI.

        The message is ``"Game On"`` while the game runs; once it is over it
        starts with ``"End of the game"`` and is suffixed according to the
        sign of the human's reward (no suffix when the reward is zero).
        """
        if not self.done:
            return self.env.render(), "Game On"

        end_message = "End of the game"
        human_reward = self.reward[self.human]
        if human_reward > 0:
            end_message += ": You WIN !!"
        elif human_reward < 0:
            end_message += ": You LOSE !!"
        return self.env.render(), end_message

    @property
    def player_id(self):
        """First key of the current observation dict.

        Presumably the wrapper only exposes the observation of the agent
        whose turn it is, which makes this the id of the agent to move.
        """
        return next(iter(self.obs))

    @property
    def legal_moves(self):
        """Array of column indices whose ``action_mask`` entry equals 1."""
        mask = np.asarray(self.obs[self.player_id]["action_mask"])
        return np.flatnonzero(mask == 1)
# Number of "drop a token" buttons, one per board column.
NUM_COLUMNS = 7

demo = gr.Blocks()

with demo:
    # NOTE(review): a single module-level game object is used by every event
    # handler below, so the game state is presumably shared by all visitors
    # of the app -- confirm whether per-session state is wanted.
    connect4 = Connect4("You")
    connect4.get_algo(str(CHECKPOINT))

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("# Let's Play Connect Four !")

            who_plays_first = gr.Radio(
                label="Who plays first", choices=["You", "Bot"], value="You"
            )
            reinitialize = gr.Button("New Game")
            game_state = gr.Text(value="Game On", interactive=False, label="Status")

        with gr.Column(scale=1):
            output = gr.Image(
                label="Connect Four Grid",
                type="numpy",
                show_label=False,
                value=connect4.env.render(),
            )

            # NOTE(review): the button row is assumed to sit under the board
            # image, inside the same column -- confirm the intended nesting.
            with gr.Row():
                drop_token_btns = []
                for _ in range(NUM_COLUMNS):
                    with gr.Column(scale=1, min_width=20):
                        drop_token_btns.append(gr.Button("X"))

    # Changing who opens the game immediately starts a new game.
    who_plays_first.change(
        connect4.init_env, who_plays_first, outputs=[output, game_state]
    )

    def reinit_game(first_player):
        """Start a new game and make the "who plays first" radio editable."""
        frame, status = connect4.init_env(first_player)
        return frame, status, gr.update(interactive=True)

    reinitialize.click(
        reinit_game, who_plays_first, outputs=[output, game_state, who_plays_first]
    )

    def wait(game_state_value):
        """Pause before the bot's reply and lock the radio while the game runs."""
        if game_state_value == "Game On":
            time.sleep(1)
            return gr.update(interactive=False)
        return gr.update(interactive=True)

    def bot(game_state_value):
        """Let the bot answer the human's move if the game is still running.

        Returns the new frame, the new status and an update that unlocks the
        radio only when the game is over.
        """
        if game_state_value != "Game On":
            # Game already finished: leave the image and the status untouched.
            return gr.update(), game_state_value, gr.update(interactive=True)

        frame, status = connect4.play()
        return frame, status, gr.update(interactive=status != "Game On")

    def drop_token(column):
        """Build the click handler playing the human's token in ``column``."""

        def handler():
            return connect4.play(column)

        return handler

    # Each button plays the human move, waits a moment, then lets the bot reply.
    for column, button in enumerate(drop_token_btns):
        button.click(
            drop_token(column),
            outputs=[output, game_state],
        ).then(
            wait, inputs=[game_state], outputs=who_plays_first
        ).then(bot, inputs=[game_state], outputs=[output, game_state, who_plays_first])

    def game_state_change(value):
        """Enable the drop buttons while the game is on, disable them otherwise."""
        playing = value == "Game On"
        return [gr.update(interactive=playing) for _ in drop_token_btns]

    game_state.change(
        game_state_change,
        game_state,
        outputs=drop_token_btns,
    )

demo.launch()