Spaces:
Sleeping
Sleeping
"""Interface to play against the model. | |
""" | |
import huggingface_hub | |
import chess | |
import chess.svg | |
import uuid | |
import random | |
import wandb | |
import gradio as gr | |
from . import constants | |
# Hugging Face Hub repository of the model that proposes the AI's moves.
model_name = "yp-edu/gpt2-stockfish-debug"
# Extra header sent with every inference request; per the HF Inference API
# docs it asks the server to wait for the model to load (verify).
headers = {"X-Wait-For-Model": "true"}
client = huggingface_hub.InferenceClient(model=model_name, headers=headers)
# Text-generation entry point; play_ai_move calls it to get a move.
inference_fn = client.text_generation
def plot_board(
    board: chess.Board,
):
    """Render ``board`` to an SVG file and return the path of that file.

    The last move played (if any) is drawn as an arrow. When the side to move
    is in check, its king square is highlighted. Every call writes a new file
    named with a random UUID under ``constants.FIGURE_DIRECTORY``. Nothing
    here deletes old files.

    Args:
        board: Position to draw.

    Returns:
        Path (as a string) of the written ``.svg`` file.
    """
    if board.move_stack:
        last_move = board.peek()
        arrows = [(last_move.from_square, last_move.to_square)]
    else:
        arrows = []
    check = board.king(board.turn) if board.is_check() else None
    svg_board = chess.svg.board(
        board,
        check=check,
        size=350,
        arrows=arrows,
    )
    # A random UUID gives every render its own file, so no earlier image is
    # overwritten. The path is built once and reused for writing and returning.
    figure_path = f"{constants.FIGURE_DIRECTORY}/board_{uuid.uuid4()}.svg"
    with open(figure_path, "w", encoding="utf-8") as svg_file:
        svg_file.write(svg_board)
    return figure_path
def render_board(
    current_board: chess.Board,
):
    """Build the three display values shown for ``current_board``.

    Returns:
        A ``(fen, moves, image_path)`` tuple. ``fen`` is the board's FEN
        string. ``moves`` is the SAN move list replayed from the game's root
        position. ``image_path`` is the path of a freshly written SVG
        rendering produced by ``plot_board``.
    """
    # variation_san replays moves from the board it is called on, so it must
    # start from the root (pre-first-move) position.
    start_position = current_board.root()
    return (
        current_board.fen(),
        start_position.variation_san(current_board.move_stack),
        plot_board(current_board),
    )
def play_user_move(
    uci_move: str,
    current_board: chess.Board,
):
    """Push the player's move, given in UCI notation, onto ``current_board``.

    The board is modified in place and is also returned. Errors raised by
    ``chess.Board.push_uci`` for malformed or illegal moves propagate to the
    caller.
    """
    board = current_board
    board.push_uci(uci_move)
    return board
def play_ai_move(
    current_board: chess.Board,
    temperature: float = 0.1,
    top_k: int = 3,
):
    """Ask the remote model for a move and push it onto ``current_board``.

    The model is prompted with the position's FEN. Its whole completion,
    stripped of surrounding whitespace, is parsed as one UCI move. The board
    is modified in place and returned.

    Args:
        current_board: Position to move from; mutated in place.
        temperature: Sampling temperature forwarded to ``inference_fn``.
        top_k: Top-k sampling cutoff forwarded to ``inference_fn``.

    Raises:
        Whatever ``inference_fn`` raises (e.g. request failures), and whatever
        ``chess.Board.push_uci`` raises when the completion is not a legal
        UCI move.
    """
    prompt = f"FEN: {current_board.fen()}\nMOVE:"
    completion = inference_fn(prompt=prompt, temperature=temperature, top_k=top_k)
    move_text = completion.strip()
    current_board.push_uci(move_text)
    return current_board
def _record_win(username: str, board: chess.Board):
    """Log a game won by ``username`` to the W&B leaderboard project.

    ``winin`` is the board's full-move number when the game ended. ``pgn`` is
    the SAN move list replayed from the game's root position.
    """
    # The context manager finishes the run on exit, so no explicit run.finish().
    with wandb.init(project="gpt2-stockfish-debug", entity="yp-edu") as run:
        run.log(
            {
                "username": username,
                "winin": board.fullmove_number,
                # variation_san checks each move starting from the board it is
                # called on, so it must be the root position. Calling it on the
                # final position always raised, and no win was ever logged.
                "pgn": board.root().variation_san(board.move_stack),
            }
        )


def try_play_move(
    username: str,
    move_to_play: str,
    current_board: chess.Board,
):
    """Handle one "Play" click: apply the user's move, then let the AI answer.

    Args:
        username: Name recorded on the leaderboard if this move wins the game.
        move_to_play: The user's move in UCI notation (e.g. ``"e2e4"``).
            Surrounding whitespace is ignored.
        current_board: The session's board; it is mutated in place.

    Returns:
        ``(fen, moves, image_path, board)``: the values from ``render_board``
        followed by the updated board for the session state.
    """
    if current_board.is_game_over():
        gr.Warning("The game is already over")
        return *render_board(current_board), current_board
    try:
        current_board = play_user_move((move_to_play or "").strip(), current_board)
    except ValueError:
        # python-chess reports malformed and illegal UCI moves as ValueError
        # (or a subclass of it).
        gr.Warning("Invalid move")
        return *render_board(current_board), current_board
    if current_board.is_checkmate():
        # Only a checkmate delivered by the user is a win. Other game-over
        # states (stalemate, insufficient material, ...) are handled below.
        gr.Info(f"Congratulations, {username}!")
        try:
            _record_win(username, current_board)
        except Exception:
            # Leaderboard logging is best-effort; the game result still stands.
            gr.Warning("Could not record the win on the leaderboard")
        return *render_board(current_board), current_board
    if current_board.is_game_over():
        gr.Info("The game ended in a draw")
        return *render_board(current_board), current_board
    temperature_retries = [(i + 1) / 10 for i in range(10)]
    for temperature in temperature_retries:
        try:
            current_board = play_ai_move(current_board, temperature=temperature)
            break
        except Exception:
            # Network errors or an unparsable/illegal generated move: retry
            # with a higher temperature to sample a different completion.
            gr.Warning(f"AI move failed with temperature {temperature}")
    else:
        gr.Warning("AI move failed with all temperatures")
        # Undo the user's move so the position is back where it was.
        current_board.pop()
    return *render_board(current_board), current_board
def _start_game():
    """Create a fresh game when a page (session) loads.

    The AI's colour is drawn at random for each session. If the AI plays
    white, its opening move is made before the board is shown. If that
    request fails, the user plays white instead.

    Returns:
        ``(fen, moves, image_path, board)`` for the display outputs and the
        session state.
    """
    board = chess.Board()
    if random.choice([True, False]):
        try:
            board = play_ai_move(board)
        except Exception:
            # A cold or unreachable model must not break page load.
            gr.Warning("AI opening move failed, you play white")
            board = chess.Board()
    return *render_board(board), board


with gr.Blocks() as interface:
    with gr.Column():
        username = gr.Textbox(
            label="Username to record on leaderboard (should you win)",
            lines=1,
            max_lines=1,
            value="",
        )
        current_fen = gr.Textbox(
            label="Board FEN",
            lines=1,
            max_lines=1,
            value=chess.STARTING_FEN,
        )
        current_pgn = gr.Textbox(
            label="Action sequence",
            lines=1,
            value="",
        )
        with gr.Row():
            move_to_play = gr.Textbox(
                label="Move to play (UCI)",
                lines=1,
                max_lines=1,
                value="",
            )
            play_button = gr.Button("Play")
    with gr.Column():
        image_board = gr.Image(label="Board")
    static_inputs = [
        username,
        move_to_play,
    ]
    static_outputs = [
        current_fen,
        current_pgn,
        image_board,
    ]
    # Placeholder until _start_game sets the per-session board on page load.
    # The colour and opening are no longer fixed once at import time.
    state_board = gr.State(value=chess.Board())
    play_button.click(
        try_play_move,
        inputs=[*static_inputs, state_board],
        # Write the updated board back into the same session state. Previously
        # it went to a fresh, unused gr.State().
        outputs=[*static_outputs, state_board],
    )
    interface.load(_start_game, outputs=[*static_outputs, state_board])