viz-gpt2-stockfish-debug / src /play_interface.py
Xmaster6y's picture
state output
128d193 unverified
raw
history blame
4.52 kB
"""Interface to play against the model.
"""
import huggingface_hub
import chess
import chess.svg
import uuid
import random
import wandb
import gradio as gr
from . import constants
# Hugging Face model queried for the AI's moves.
model_name = "yp-edu/gpt2-stockfish-debug"
# Extra HTTP header sent with every request; presumably asks the Inference API
# to wait for a cold model to load rather than erroring out — confirm in the
# Inference API docs.
headers = {"X-Wait-For-Model": "true"}
client = huggingface_hub.InferenceClient(headers=headers, model=model_name)
# Text-generation callable that play_ai_move sends its FEN prompt to.
inference_fn = client.text_generation
def plot_board(
    board: chess.Board,
) -> str:
    """Render ``board`` as an SVG file and return the path of that file.

    The most recent move (if any) is drawn as an arrow, and the king of the
    side to move is highlighted when it is in check.

    Args:
        board: Position to draw.

    Returns:
        Path of the written file, inside ``constants.FIGURE_DIRECTORY``.
    """
    arrows = []
    if board.move_stack:
        last_move = board.peek()
        arrows.append((last_move.from_square, last_move.to_square))
    check = board.king(board.turn) if board.is_check() else None
    svg_board = chess.svg.board(
        board,
        check=check,
        size=350,
        arrows=arrows,
    )
    # A fresh UUID per render gives every image a distinct file name
    # (presumably so concurrent sessions never overwrite each other's board).
    # NOTE(review): nothing here deletes old files; the directory keeps growing.
    file_path = f"{constants.FIGURE_DIRECTORY}/board_{uuid.uuid4()}.svg"
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(svg_board)
    return file_path
def render_board(
    current_board: chess.Board,
):
    """Build the three values shown in the UI for ``current_board``.

    Returns:
        A ``(fen, move_text, image_path)`` tuple: the FEN string, the moves
        played so far in SAN (replayed from the root position), and the path
        of a freshly rendered SVG of the position.
    """
    move_text = current_board.root().variation_san(current_board.move_stack)
    return current_board.fen(), move_text, plot_board(current_board)
def play_user_move(
    uci_move: str,
    current_board: chess.Board,
):
    """Apply the user's move, given in UCI notation, to ``current_board``.

    Args:
        uci_move: Move such as ``"e2e4"`` or ``"e7e8q"``; surrounding
            whitespace is ignored.
        current_board: Board to update; it is mutated in place.

    Returns:
        The same board object, after the move.

    Raises:
        ValueError: If the move is malformed or illegal (python-chess raises
            ``ValueError`` subclasses), or if it is the null move ``"0000"``.
    """
    move = current_board.parse_uci(uci_move.strip())
    # parse_uci returns the null move for "0000" without any legality check;
    # pushing it would let the player skip a turn (even while in check).
    if not move:
        raise ValueError(f"Null move is not allowed: {uci_move!r}")
    current_board.push(move)
    return current_board
def play_ai_move(
    current_board: chess.Board,
    temperature: float = 0.1,
    top_k: int = 3,
):
    """Ask the model for a move in the current position and play it in place.

    The model is prompted with ``"FEN: <fen>\\nMOVE:"`` and its generation is
    parsed as a UCI move.

    Args:
        current_board: Board to update; it is mutated only on success.
        temperature: Sampling temperature forwarded to ``inference_fn``.
        top_k: Top-k sampling value forwarded to ``inference_fn``.

    Returns:
        The same board object, after the AI's move.

    Raises:
        ValueError: If the generation does not start with a legal, non-null
            UCI move.
        Any exception raised by ``inference_fn`` itself (e.g. HTTP errors).
    """
    generated = inference_fn(
        prompt=f"FEN: {current_board.fen()}\nMOVE:",
        temperature=temperature,
        top_k=top_k,
    )
    # Only the first whitespace-separated token is the move, in case the model
    # keeps generating afterwards. An empty generation becomes "", which
    # parse_uci rejects with a ValueError.
    tokens = generated.split()
    uci_move = tokens[0] if tokens else ""
    move = current_board.parse_uci(uci_move)
    # "0000" parses as a null move with no legality check; treat it as a failed
    # generation so the caller can retry instead of letting the AI pass.
    if not move:
        raise ValueError(f"Model produced a null move: {generated!r}")
    current_board.push(move)
    return current_board
def _record_win(username: str, board: chess.Board) -> None:
    """Log a user's win (name, final move number, SAN move list) to W&B."""
    record = {
        "username": username,
        "winin": board.fullmove_number,
        # variation_san replays the moves from the board it is called on, so it
        # must be the root (starting) position, not the final one.
        "pgn": board.root().variation_san(board.move_stack),
    }
    # The context manager finishes the run on exit; no explicit finish() needed.
    with wandb.init(project="gpt2-stockfish-debug", entity="yp-edu") as run:
        run.log(record)


def try_play_move(
    username: str,
    move_to_play: str,
    current_board: chess.Board,
):
    """Handle a click on "Play": apply the user's move, then let the AI reply.

    Args:
        username: Name recorded on the leaderboard if the user wins.
        move_to_play: The user's move in UCI notation.
        current_board: The session's board; it is mutated in place.

    Returns:
        ``(fen, pgn, image_path, board)``: the values for the three display
        widgets followed by the board for the session state.

    If the user's move is invalid, the board is left unchanged. If the AI fails
    to produce a legal move at every retry temperature, the user's move is
    undone so the user can try again.
    """
    if current_board.is_game_over():
        gr.Warning("The game is already over")
        return *render_board(current_board), current_board

    try:
        current_board = play_user_move(move_to_play, current_board)
    except ValueError:
        gr.Warning("Invalid move")
        return *render_board(current_board), current_board

    if current_board.is_game_over():
        # The user just moved, so checkmate here means the user won; any other
        # game-over state (stalemate, insufficient material, ...) is a draw and
        # must not be recorded as a win.
        if current_board.is_checkmate():
            gr.Info(f"Congratulations, {username}!")
            try:
                _record_win(username, current_board)
            except Exception:
                # Logging is best-effort: never report a winning move as invalid.
                gr.Warning("Could not record the win on the leaderboard")
        else:
            gr.Info("The game is a draw")
        return *render_board(current_board), current_board

    # Retry with increasingly random sampling: 0.1, 0.2, ..., 1.0.
    temperature_retries = [(i + 1) / 10 for i in range(10)]
    for temperature in temperature_retries:
        try:
            current_board = play_ai_move(current_board, temperature=temperature)
            break
        except Exception:
            gr.Warning(f"AI move failed with temperature {temperature}")
    else:
        gr.Warning("AI move failed with all temperatures")
        # Undo the user's move so the position is back to before the click.
        current_board.pop()
    return *render_board(current_board), current_board
with gr.Blocks() as interface:
    with gr.Column():
        username = gr.Textbox(
            label="Username to record on leaderboard (should you win)",
            lines=1,
            max_lines=1,
            value="",
        )
        current_fen = gr.Textbox(
            label="Board FEN",
            lines=1,
            max_lines=1,
            value=chess.STARTING_FEN,
        )
        current_pgn = gr.Textbox(
            label="Action sequence",
            lines=1,
            value="",
        )
        with gr.Row():
            move_to_play = gr.Textbox(
                label="Move to play (UCI)",
                lines=1,
                max_lines=1,
                value="",
            )
            play_button = gr.Button("Play")
    with gr.Column():
        image_board = gr.Image(label="Board")

    static_inputs = [
        username,
        move_to_play,
    ]
    static_outputs = [
        current_fen,
        current_pgn,
        image_board,
    ]

    # The AI's colour (and, if white, its opening move) is chosen once when the
    # module is imported; gr.State deep-copies this initial board per session.
    is_ai_white = random.choice([True, False])
    init_board = chess.Board()
    if is_ai_white:
        try:
            init_board = play_ai_move(init_board)
        except Exception as exc:
            # An unavailable model must not stop the app from starting:
            # fall back to a fresh board on which the user moves first.
            print(f"Initial AI move failed, user plays white: {exc!r}")
            is_ai_white = False
            init_board = chess.Board()
    state_board = gr.State(value=init_board)

    # try_play_move returns the session board, so write it back to the same
    # State it was read from (not to a new, unrelated gr.State()).
    play_button.click(
        try_play_move,
        inputs=[*static_inputs, state_board],
        outputs=[*static_outputs, state_board],
    )
    # Fill the FEN / move list / image widgets from the session board on load.
    interface.load(render_board, inputs=[state_board], outputs=[*static_outputs])