File size: 5,491 Bytes
cd63926
ffe7549
 
9ca3472
 
 
ffe7549
9ca3472
 
ffe7549
d1757d4
7db569a
d1757d4
9ca3472
d1757d4
 
9ca3472
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ffe7549
9ca3472
 
 
ef91ec6
ffe7549
 
 
 
 
 
 
 
 
9ca3472
ffe7549
 
 
 
 
 
9ca3472
ffe7549
 
 
9ca3472
ef91ec6
 
ffe7549
9ca3472
 
 
 
 
 
 
 
ffe7549
9ca3472
 
ffe7549
 
9ca3472
ef91ec6
ffe7549
 
 
 
a9fd1f2
9ca3472
 
 
 
 
 
ffe7549
 
 
 
 
9ca3472
 
 
ffe7549
 
9ca3472
 
 
 
 
 
 
ef91ec6
9ca3472
 
 
 
 
 
 
ef91ec6
 
9ca3472
 
 
 
 
a9fd1f2
ef91ec6
9ca3472
 
ffe7549
 
 
 
ef91ec6
ffe7549
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import logging
import time

import gradio as gr
import numpy as np
import onnxruntime as ort
from pettingzoo.classic import connect_four_v3

from connectfour import ERROR_SCREEN
from models import MODEL_PATH

# Maintenance commands:
#   refresh requirements.txt: poetry export -f requirements.txt --output requirements.txt --without-hashes
#   launch via the Gradio CLI: gradio connectfour/app.py

# ONNX policy model, loaded once at import time; play() queries it for bot moves.
session = ort.InferenceSession(str(MODEL_PATH), None)
# Root Gradio container; the UI is assembled inside `with demo:`.
demo = gr.Blocks()

# Board width: number of columns a token can be dropped into (one button each).
column_count = 7
# Status text while a game is in progress; the handlers compare the status
# against this exact string to decide whether play continues.
game_on_msg = "Game On"


def flatten_observation(obs):
    """Pack a PettingZoo observation dict into a single model input row.

    The action mask is placed first, followed by the board planes flattened
    in C order to ``2 * 6 * column_count`` values (presumably 6 rows x
    ``column_count`` columns x 2 player planes -- confirm against the env).
    A leading batch axis is added and the result is cast to float32.

    Returns:
        ``np.ndarray`` of shape ``(1, len(mask) + 2 * 6 * column_count)``.
    """
    planes = np.reshape(obs["observation"], 2 * 6 * column_count)
    mask = np.array(obs["action_mask"])
    stacked = np.hstack((mask, planes))
    return stacked.astype(np.float32)[None, :]


def legal_moves(env, player_id):
    """Return the column indices ``player_id`` may currently play.

    The indices come from the length of the observed action mask rather than
    the module-level ``column_count``. The helper therefore works for any board
    width the environment reports, and it also accepts a plain list mask.

    Args:
        env: AEC environment exposing ``observe(agent)`` that returns a dict
            with an ``"action_mask"`` entry.
        player_id: agent name whose action mask is inspected.

    Returns:
        1-D integer ``np.ndarray`` of legal column indices in ascending order;
        empty when no move is legal.
    """
    mask = np.asarray(env.observe(player_id)["action_mask"])
    return np.flatnonzero(mask == 1)


def done(env):
    """Report whether any agent's episode has terminated or been truncated.

    Returns a NumPy boolean; ``False`` when both status dicts are empty.
    """
    flags = [*env.terminations.values(), *env.truncations.values()]
    return np.any(flags)


def get_state_msg(env, human):
    """Build the status text for the current game from ``human``'s viewpoint.

    Returns ``game_on_msg`` while play continues. Once the episode is over it
    returns "End of the game". A positive reward for ``human`` appends a
    win suffix, a negative one a loss suffix, and zero appends nothing.
    """
    if not done(env):
        return game_on_msg

    reward = env.rewards[human]
    suffix = ""
    if reward > 0:
        suffix = ": You WIN !!"
    elif reward < 0:
        suffix = ": You LOSE !!"
    return "End of the game" + suffix


def play(env, human, action=None):
    """Advance the game by one move and report the resulting status.

    When it is not ``human``'s turn, the ONNX policy chooses the move (argmax
    of its ``"output"``) and ``action`` is ignored. Whatever move results is
    replaced by a uniformly random legal column if it is not legal.

    Args:
        env: PettingZoo connect-four AEC environment; it is stepped in place.
        human: agent name controlled by the user.
        action: column chosen by the user; ignored on the bot's turn.

    Returns:
        ``(env, status)``, where ``status`` comes from ``get_state_msg``, or
        "Restart the Game" if anything failed while making the move.
    """
    try:
        player = env.agent_selection
        if human != player:
            outputs = session.run(
                ["output"],
                {
                    "obs": flatten_observation(env.observe(player)),
                    "state_ins": [],
                },
            )
            action = int(np.argmax(outputs[0]))

        allowed = legal_moves(env, player)
        if action not in allowed:
            # Covers an illegal model prediction, a click on a full column and
            # action=None. With no legal move at all this raises and is
            # reported below.
            action = np.random.choice(allowed)

        env.step(action)
        return env, get_state_msg(env, human)
    except Exception:
        # UI boundary: keep the app alive, but record why the move failed
        # instead of silently discarding the traceback.
        logging.getLogger(__name__).exception("Move failed; asking the user to restart")
        return env, "Restart the Game"


def init_env(env, who_plays_first, human):
    """Reset ``env`` for a new game, letting the bot open unless the user starts.

    Args:
        env: environment to reset in place.
        who_plays_first: radio choice; anything other than "You" makes the
            bot play the first move via ``play``.
        human: agent name controlled by the user.

    Returns:
        The same (mutated) environment object.
    """
    env.reset()
    bot_opens = who_plays_first != "You"
    if bot_opens:
        play(env, human)  # the status message is not needed here
    return env


def error_screen():
    """Load and return the array stored in the ``ERROR_SCREEN`` file via ``np.load``.

    The file is opened in binary mode and closed before returning.
    """
    with open(ERROR_SCREEN, "rb") as handle:
        return np.load(handle)


def create_env():
    """Build a fresh ``rgb_array``-rendering Connect Four env.

    The env is reset with the user ("You") moving first as ``player_0``.
    """
    fresh = connect_four_v3.env(render_mode="rgb_array")
    return init_env(fresh, "You", "player_0")


with demo:
    # Per-session state: the agent name the user controls and the live env.
    human = gr.State("player_0")
    env = gr.State(create_env())

    drop_token_btns = []

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("# Let's Play Connect Four !")

            who_plays_first = gr.Radio(
                label="Who plays first", choices=["You", "Bot"], value="You"
            )
            reinitialize = gr.Button("New Game")
            game_state = gr.Text(value=game_on_msg, interactive=False, label="Status")

        with gr.Column(scale=1):
            output = gr.Image(
                label="Connect Four Grid",
                type="numpy",
                show_label=False,
                value=error_screen(),
            )

            # One drop button per column; elem_id carries the column index so a
            # single handler (click_column) can serve every button.
            with gr.Row():
                for i in range(column_count):
                    with gr.Column(min_width=20):
                        drop_token_btns.append(gr.Button("X", elem_id=i))

    def reinit_game(env, who_plays_first, human):
        """Start a new game; return updates for env, human, status and the radio.

        The incoming ``human`` (previous seat) is accepted to match the event
        wiring but is not used; the seat is derived from ``who_plays_first``.
        """
        # create_env pairs "You" with "player_0", i.e. player_0 opens after a
        # reset. The seat must be decided before init_env runs. Otherwise a
        # stale "player_0" state with the bot asked to open makes play() treat
        # the bot's turn as the user's and drop a random token instead of the
        # model's move.
        human = "player_0" if who_plays_first == "You" else "player_1"
        env = init_env(env, who_plays_first, human)
        return [
            env,
            human,  # human
            get_state_msg(env, human),  # state_msg
            gr.Checkbox.update(interactive=True),  # who_plays_first
        ]

    def on_render_change(env):
        """Return the env's current rendered frame for the grid image."""
        return env.render()

    def wait(game_state_value):
        """Pause before the bot replies; lock the radio only while the game is on."""
        if game_state_value == game_on_msg:
            time.sleep(0.7)
            return gr.Checkbox.update(interactive=False)
        else:
            return gr.Checkbox.update(interactive=True)

    def bot(env, game_state_value, human):
        """Let the bot move if the game is still on; return status and radio update."""
        if game_state_value == game_on_msg:
            env, state_msg = play(env, human)
            if state_msg == game_on_msg:
                return state_msg, gr.Checkbox.update(interactive=False)
            else:
                return state_msg, gr.Checkbox.update(interactive=True)
        return (
            game_state_value,
            gr.Checkbox.update(interactive=True),
        )

    def click_column(env, human, evt: gr.EventData):
        """Play the user's move in the column stored in the clicked button's elem_id."""
        env, state_msg = play(env, human, int(evt.target.elem_id))
        return env, state_msg

    def game_state_change(value):
        """Enable every drop button only while the status says the game is on."""
        return [gr.Button.update(interactive=value == game_on_msg)] * column_count

    # Changing who starts, or pressing "New Game", restarts and then redraws.
    who_plays_first.change(
        reinit_game,
        [env, who_plays_first, human],
        outputs=[env, human, game_state, who_plays_first],
    ).then(on_render_change, inputs=[env], outputs=[output])

    reinitialize.click(
        reinit_game,
        [env, who_plays_first, human],
        outputs=[env, human, game_state, who_plays_first],
    ).then(on_render_change, inputs=[env], outputs=[output])

    # Column click chain: user move -> redraw -> pause -> bot move -> redraw.
    for i in range(column_count):
        drop_token_btns[i].click(
            click_column,
            inputs=[env, human],
            outputs=[env, game_state],
        ).then(on_render_change, inputs=[env], outputs=[output]).then(
            wait, inputs=[game_state], outputs=[who_plays_first]
        ).then(
            bot, inputs=[env, game_state, human], outputs=[game_state, who_plays_first]
        ).then(
            on_render_change, inputs=[env], outputs=[output]
        )

    # The status text decides whether the drop buttons are clickable.
    game_state.change(
        game_state_change,
        game_state,
        outputs=drop_token_btns,
    )

# Start the Gradio app assembled in the `with demo:` block.
demo.launch()