# Author: ClementBM
# Commit d1757d4: "remove dependencies"
import time
import gradio as gr
import numpy as np
import ray
import ray.rllib.algorithms.ppo as ppo
from pettingzoo.classic import connect_four_v3
from ray.tune import register_env
from connectfour.checkpoint import CHECKPOINT
from connectfour.training.models import Connect4MaskModel
from connectfour.training.wrappers import Connect4Env
# Name of the policy, inside the restored PPO checkpoint, that chooses the bot's moves.
POLICY_ID: str = "learned_v5"

# To regenerate requirements.txt with Poetry:
# poetry export -f requirements.txt --output requirements.txt --without-hashes
class Connect4:
    """One Connect Four game between a human and an RLlib PPO policy.

    Wraps a PettingZoo ``connect_four_v3`` environment in ``Connect4Env`` and
    exposes the rendered board plus a status message for the UI.
    """

    def __init__(self, who_plays_first: str) -> None:
        """Start Ray, register the "connect4" env for RLlib, and start a game.

        Args:
            who_plays_first: "You" lets the human move first; any other value
                makes the bot move first. In that case the policy must already
                be loaded. NOTE(review): ``get_algo`` can only be called after
                construction, so construct with "You" and load the checkpoint
                afterwards.
        """
        ray.init(include_dashboard=False, ignore_reinit_error=True)

        def env_creator(config):
            # RLlib passes an env config; this environment takes no options.
            return connect_four_v3.env(render_mode="rgb_array")

        # Register under the name that ``get_algo`` passes to ``PPOConfig.environment``.
        register_env("connect4", lambda config: Connect4Env(env_creator(config)))
        self.init_env(who_plays_first)

    def init_env(self, who_plays_first: str):
        """Reset to a fresh game; if the bot starts, let it play its first move.

        Args:
            who_plays_first: "You" for the human to start, anything else for the bot.

        Returns:
            The ``(rendered_board, status_message)`` pair from ``render_and_state``.
        """
        self.env = Connect4Env(connect_four_v3.env(render_mode="rgb_array"))
        self.done = False
        self.obs, _ = self.env.reset()
        if who_plays_first != "You":
            # ``self.human`` is either unset (first game) or stale (it holds the
            # previous game's id). Clear it so ``play`` treats the player to move
            # as the bot and queries the policy. Otherwise it may match the
            # current player, and the opening move falls back to a random column.
            self.human = None
            self.play()
        # Whoever is to move now is the human.
        self.human = self.player_id
        return self.render_and_state

    def get_algo(self, checkpoint):
        """Build a PPO algorithm on the "connect4" env and restore ``checkpoint``.

        The algorithm is stored in ``self.algo``, which ``play`` uses for bot moves.
        """
        config = (
            ppo.PPOConfig()
            .environment("connect4")
            .framework("torch")
            .training(model={"custom_model": Connect4MaskModel})
        )
        # Turn off RLlib exploration for inference.
        config.explore = False
        self.algo = config.build()
        self.algo.restore(checkpoint)

    def play(self, action=None):
        """Apply one move for the player to move.

        If the player to move is not ``self.human``, ``action`` is ignored and the
        policy ``POLICY_ID`` picks the move. Any action that is not in
        ``legal_moves`` (e.g. a full column) is replaced by a random legal column.

        Returns:
            The ``(rendered_board, status_message)`` pair from ``render_and_state``.
        """
        if self.human != self.player_id:
            action = self.algo.compute_single_action(
                self.obs[self.player_id], policy_id=POLICY_ID
            )
        if action not in self.legal_moves:
            # NOTE(review): illegal choices are silently randomised, not rejected.
            action = np.random.choice(self.legal_moves)
        self.obs, self.reward, terminated, truncated, _ = self.env.step(
            {self.player_id: action}
        )
        self.done = terminated["__all__"] or truncated["__all__"]
        return self.render_and_state

    @property
    def render_and_state(self):
        """Return the rendered board and a status string for the UI.

        The status is "Game On" while playing. Once the game is over it reads
        "End of the game", with ": You WIN !!" or ": You LOSE !!" appended
        according to the sign of the human's reward. A zero reward keeps the
        plain message (presumably a draw).
        """
        if not self.done:
            return self.env.render(), "Game On"
        end_message = "End of the game"
        if self.reward[self.human] > 0:
            end_message += ": You WIN !!"
        elif self.reward[self.human] < 0:
            end_message += ": You LOSE !!"
        return self.env.render(), end_message

    @property
    def player_id(self):
        """Agent id of the player to move: the first key of ``self.obs``.

        Assumes ``Connect4Env`` keys observations by the agent to move; TODO confirm.
        """
        return list(self.obs.keys())[0]

    @property
    def legal_moves(self):
        """Column indices 0-6 whose ``action_mask`` entry is 1 for the player to move."""
        return np.arange(7)[self.obs[self.player_id]["action_mask"] == 1]
demo = gr.Blocks()

# NOTE(review): a single module-level ``Connect4`` instance backs the UI, so all
# browser sessions share (and can interfere with) the same game.
with demo:
    connect4 = Connect4("You")
    connect4.get_algo(str(CHECKPOINT))

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("# Let's Play Connect Four !")
            who_plays_first = gr.Radio(
                label="Who plays first", choices=["You", "Bot"], value="You"
            )
            reinitialize = gr.Button("New Game")
            game_state = gr.Text(value="Game On", interactive=False, label="Status")
        with gr.Column(scale=1):
            output = gr.Image(
                label="Connect Four Grid",
                type="numpy",
                show_label=False,
                value=connect4.env.render(),
            )
            # One "X" button per board column; button i drops a token in column i.
            with gr.Row():
                drop_buttons = []
                for _column in range(7):
                    with gr.Column(scale=1, min_width=20):
                        drop_buttons.append(gr.Button("X"))

    def _make_drop_handler(column):
        """Return a no-argument click handler that plays the human's move in ``column``.

        A factory is used so each handler captures its own ``column``; a lambda
        written directly in the loop would late-bind to the last column.
        """

        def drop_token():
            return connect4.play(column)

        return drop_token

    def reinit_game(who_plays_first):
        """Start a new game and re-enable the "Who plays first" radio."""
        image, status = connect4.init_env(who_plays_first)
        return image, status, gr.Radio.update(interactive=True)

    def wait(game_state_value):
        """Pause one second before the bot replies and lock the radio mid-game.

        If the game is already over, unlock the radio immediately.
        """
        if game_state_value == "Game On":
            time.sleep(1)
            return gr.Radio.update(interactive=False)
        return gr.Radio.update(interactive=True)

    def bot(game_state_value):
        """Let the bot move while the game is on; unlock the radio once it ends."""
        if game_state_value != "Game On":
            # Game already finished: leave the board and status untouched.
            return (
                gr.Image.update(),
                game_state_value,
                gr.Radio.update(interactive=True),
            )
        image, status = connect4.play()
        return image, status, gr.Radio.update(interactive=status != "Game On")

    def game_state_change(value):
        """Enable the drop buttons only while the status reads "Game On"."""
        enabled = value == "Game On"
        return [gr.Button.update(interactive=enabled) for _ in drop_buttons]

    who_plays_first.change(
        connect4.init_env, who_plays_first, outputs=[output, game_state]
    )
    reinitialize.click(
        reinit_game, who_plays_first, outputs=[output, game_state, who_plays_first]
    )
    for column, button in enumerate(drop_buttons):
        # Human move -> one-second pause (radio locked) -> bot reply.
        button.click(
            _make_drop_handler(column),
            outputs=[output, game_state],
        ).then(
            wait, inputs=[game_state], outputs=who_plays_first
        ).then(bot, inputs=[game_state], outputs=[output, game_state, who_plays_first])
    game_state.change(game_state_change, game_state, outputs=drop_buttons)

demo.launch()