# Hugging Face Spaces app (Space status at capture time: Running).
import base64
import io
import random
import re

import chess
import chess.pgn
import chess.svg
import requests
import streamlit as st
import torch

from o1.agent import Agent
# Where the pretrained o1 weights live (loaded below via load_state_dict).
HF_MODEL_URL = "https://huggingface.co/FlameF0X/o1/resolve/main/o1_agent.pth"

# Page setup, kept as the first Streamlit command the script issues.
st.set_page_config(layout="centered", page_title="Play Chess vs o1")
@st.cache_resource
def load_o1():
    """Download the o1 weights from Hugging Face and return a ready agent.

    Streamlit re-executes this whole script on every widget interaction, so
    the result is cached with ``st.cache_resource``. The download and model
    construction then happen once per server process, not on every click.
    A failed call raises and is not cached, so it is retried on the next run.

    Returns:
        An ``Agent`` whose ``model`` has the downloaded state dict loaded and
        has been switched to eval mode.

    Raises:
        RuntimeError: if Hugging Face does not answer with HTTP 200.
        requests.RequestException: on network errors or timeout.
    """
    # NOTE(review): the cached agent is shared by all sessions; this assumes
    # Agent.predict keeps no per-game state between calls - confirm.
    response = requests.get(HF_MODEL_URL, timeout=60)
    if response.status_code != 200:
        raise RuntimeError("Failed to download model from Hugging Face.")

    # Load straight from memory. The old NamedTemporaryFile(delete=False)
    # left a stray file in the temp directory on every call.
    # NOTE(review): torch.load unpickles data fetched from the network;
    # consider weights_only=True once the minimum torch version allows it.
    state_dict = torch.load(io.BytesIO(response.content), map_location="cpu")

    o1 = Agent()
    o1.model.load_state_dict(state_dict)
    o1.model.eval()
    return o1
def parse_move_input(move_input, board):
    """Parse a user-typed move and return the matching legal ``chess.Move``.

    Interpretations are tried in this order, and the first legal one wins:

    1. UCI coordinates, case-insensitive (``e2e4``, ``E2E4``, ``e7e8q``).
       A bare 4-character UCI promotion such as ``e7e8`` is promoted to a
       queen.
    2. SAN as typed, then upper-cased, lower-cased and capitalised
       (so ``E4`` -> ``e4`` and ``nf3``/``NF3`` -> ``Nf3``).
    3. Castling spellings ``0-0``, ``o-o``, ``oo`` (and the queen-side forms)
       mapped to ``O-O`` / ``O-O-O``.

    Args:
        move_input: Raw text from the input box. ``None`` or blank yields
            ``None``.
        board: The ``chess.Board`` the move must be legal on.

    Returns:
        The legal ``chess.Move``, or ``None`` if no interpretation is legal.
    """
    if not move_input:
        return None
    text = move_input.strip()
    if not text:
        return None

    # 1. UCI. python-chess reports malformed text with ValueError subclasses,
    #    so only that is caught (no bare except swallowing everything).
    if len(text) >= 4 and text[2:4].isalnum():
        uci = text.lower()
        uci_candidates = [uci]
        if len(uci) == 4:
            uci_candidates.append(uci + "q")  # auto-queen a bare promotion
        for candidate in uci_candidates:
            try:
                move = chess.Move.from_uci(candidate)
            except ValueError:
                continue
            if move in board.legal_moves:
                return move

    # 2./3. SAN variants, then castling aliases python-chess rejects as-is.
    castling_aliases = {
        "0-0": "O-O",
        "0-0-0": "O-O-O",
        "oo": "O-O",
        "ooo": "O-O-O",
        "o-o": "O-O",
        "o-o-o": "O-O-O",
    }
    san_candidates = [text, text.upper(), text.lower(), text.capitalize()]
    alias = castling_aliases.get(text.lower())
    if alias is not None:
        san_candidates.append(alias)

    # dict.fromkeys drops duplicates while keeping the try order.
    for candidate in dict.fromkeys(san_candidates):
        try:
            move = board.parse_san(candidate)
        except ValueError:
            continue
        if move in board.legal_moves:
            return move

    return None
def _heuristic_move_score(board, move):
    """Fallback score for ``move``: captures first, then central landings, plus jitter."""
    score = 0
    if board.is_capture(move):
        score += 10
    file = chess.square_file(move.to_square)
    rank = chess.square_rank(move.to_square)
    # Files/ranks 2..5 form the 4x4 block c3-f6.
    if 2 <= file <= 5 and 2 <= rank <= 5:
        score += 5
    # Random jitter so equally-scored moves are picked at random.
    return score + random.random()


def get_o1_move(o1, board):
    """Pick o1's move for ``board``, degrading gracefully if the policy is unusable.

    Strategy:
      1. If ``o1.predict(board)`` returns logits of rank >= 2 whose second axis
         is at least as long as the legal-move list, play the legal move whose
         index has the highest logit (first one on ties).
      2. Otherwise play the legal move with the best ``_heuristic_move_score``.
      3. If anything raises, report it with ``st.error`` and play a random
         legal move.

    Returns:
        A legal ``chess.Move``, or ``None`` when there are no legal moves.
    """
    try:
        policy_logits, _ = o1.predict(board)
        legal_moves = list(board.legal_moves)
        if not legal_moves:
            return None

        shape = getattr(policy_logits, "shape", None)
        if shape is not None and len(shape) > 1 and shape[1] >= len(legal_moves):
            # NOTE(review): this pairs logit i with the i-th move in
            # python-chess's generation order, which is unlikely to match the
            # move encoding the o1 policy head uses - confirm against o1.agent.
            # max() with a key avoids the old sort over (score, Move) tuples.
            # That sort compared Move objects on tied scores and raised
            # TypeError.
            best_index = max(
                range(len(legal_moves)),
                key=lambda i: policy_logits[0, i].item(),
            )
            return legal_moves[best_index]

        return max(legal_moves, key=lambda move: _heuristic_move_score(board, move))
    except Exception as e:
        st.error(f"Error in o1 prediction: {e}")
        # Fallback to random legal move
        legal_moves = list(board.legal_moves)
        return random.choice(legal_moves) if legal_moves else None
# ---- Session state: board and UCI move list persist across reruns ----
for _state_key, _make_default in (("board", chess.Board), ("history", list)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

# ---- Model: if o1 cannot be loaded, the AI side falls back to random play ----
try:
    o1 = load_o1()
except Exception as e:
    st.error(f"Failed to load o1: {e}")
    o1 = None
    o1_loaded = False
else:
    o1_loaded = True

board = st.session_state.board
history = st.session_state.history

st.title("♟️ Play Chess vs o1")
if not o1_loaded:
    st.warning("o1 failed to load. The app will use random moves for the AI.")

# Resetting replaces both state objects and restarts the script run.
if st.button("Reset Game"):
    st.session_state.board = chess.Board()
    st.session_state.history = []
    st.rerun()

# Reserved slot that render_board() fills with the SVG board.
board_placeholder = st.empty()
def render_board():
    """Draw the current position into ``board_placeholder``, centred, 400px wide.

    Reads the module-level ``board`` and ``board_placeholder``. The most recent
    move (if any) is highlighted. Rendering errors are shown with ``st.error``
    rather than aborting the script.
    """
    try:
        last_move = None
        if board.move_stack:
            last_move = board.peek()
        svg_markup = chess.svg.board(board=board, lastmove=last_move, size=400)
        centered_html = (
            '<div style="display: flex; justify-content: center;">'
            f"{svg_markup}"
            "</div>"
        )
        board_placeholder.markdown(centered_html, unsafe_allow_html=True)
    except Exception as e:
        st.error(f"Error rendering board: {e}")
render_board()

# ---- Status line: side to move, plus a check indicator ----
turn_col, check_col = st.columns(2)
side_to_move = "White" if board.turn == chess.WHITE else "Black"
turn_col.write(f"**Turn:** {side_to_move}")
if board.is_check():
    check_col.write("**Check!**")
# ---- Human move (White) ----
if not board.is_game_over() and board.turn == chess.WHITE:
    st.write("### Your Turn (White)")

    # Reference lists of every legal move, in SAN and UCI.
    legal_moves = list(board.legal_moves)
    legal_moves_uci = [move.uci() for move in legal_moves]
    legal_moves_san = []
    for move in legal_moves:
        try:
            legal_moves_san.append(board.san(move))
        except Exception:  # display-only best effort: fall back to UCI
            legal_moves_san.append(move.uci())

    with st.expander("Show legal moves"):
        san_col, uci_col = st.columns(2)
        with san_col:
            st.write("**Algebraic notation:**")
            st.write(", ".join(sorted(legal_moves_san)))
        with uci_col:
            st.write("**UCI notation:**")
            st.write(", ".join(sorted(legal_moves_uci)))

    user_move = st.text_input(
        "Enter your move (e.g., E4, Nf3, e2e4, O-O):",
        key="move_input",
        help="You can use algebraic notation (E4, Nf3) or UCI notation (e2e4). Case doesn't matter!",
    )

    submit_col, random_col = st.columns(2)
    with submit_col:
        if st.button("Submit Move"):
            if not user_move:
                st.warning("Please enter a move.")
            else:
                parsed_move = parse_move_input(user_move, board)
                if not parsed_move:
                    st.warning(f"Invalid move: '{user_move}'. Please check the legal moves above.")
                else:
                    try:
                        # SAN must be computed BEFORE push. After the push the
                        # move is no longer legal in the new position and
                        # board.san() fails. That used to report "Error making
                        # move" for every typed move, with the move already
                        # applied and no rerun.
                        move_san = board.san(parsed_move)
                        board.push(parsed_move)
                        history.append(parsed_move.uci())
                    except Exception as e:
                        st.error(f"Error making move: {e}")
                    else:
                        st.success(f"You played: {move_san} ({parsed_move.uci()})")
                        # Outside the try so the rerun signal is never caught.
                        st.rerun()

    with random_col:
        if st.button("Random Move"):
            if legal_moves:
                random_move = random.choice(legal_moves)
                board.push(random_move)
                history.append(random_move.uci())
                st.rerun()
# ---- o1 move (Black) ----
if not board.is_game_over() and board.turn == chess.BLACK:
    st.write("### o1's Turn (Black)")
    move_played = False
    with st.spinner("o1 is thinking..."):
        try:
            if o1_loaded and o1:
                best_move = get_o1_move(o1, board)
            else:
                # Fallback to random move if o1 not loaded
                legal_moves = list(board.legal_moves)
                best_move = random.choice(legal_moves) if legal_moves else None
            if best_move and best_move in board.legal_moves:
                move_san = board.san(best_move)  # before push, while still legal
                board.push(best_move)
                history.append(best_move.uci())
                st.success(f"o1 played: {move_san} ({best_move.uci()})")
                move_played = True
            else:
                st.error("o1 couldn't find a valid move")
        except Exception as e:
            st.error(f"Error during o1 move: {e}")
    # st.rerun() works by raising a control-flow exception. Calling it outside
    # the broad `except Exception` guarantees it is never swallowed and
    # misreported as an o1 error.
    if move_played:
        st.rerun()
# ---- Game history: PGN export, with a plain UCI list as fallback ----
if history:
    st.write("### Game History")
    # Create PGN
    try:
        game = chess.pgn.Game()
        game.headers["Event"] = "Human vs o1"
        game.headers["White"] = "Human"
        game.headers["Black"] = "o1"
        node = game
        replay = chess.Board()
        for uci in history:
            move = chess.Move.from_uci(uci)
            if move not in replay.legal_moves:
                st.warning(f"Invalid move in history: {uci}")
                break
            node = node.add_main_variation(move)
            replay.push(move)
        # Record the outcome. Without this the exported PGN always ends in "*"
        # (unfinished), even after mate or a draw. Board.result() itself returns
        # "*" while the game is still in progress.
        game.headers["Result"] = replay.result()
        st.code(str(game), language="pgn")
    except Exception as e:
        st.error(f"Error generating PGN: {e}")
        # Fallback: one numbered line per full move, raw UCI notation.
        move_pairs = []
        for i in range(0, len(history), 2):
            white_move = history[i]
            black_move = history[i + 1] if i + 1 < len(history) else ""
            move_pairs.append(f"{i // 2 + 1}. {white_move} {black_move}")
        st.code("\n".join(move_pairs))
# ---- Game over: announce the outcome and offer a fresh game ----
if board.is_game_over():
    st.write("### Game Over!")
    result = board.result()
    outcome = board.outcome()

    # Decisive results get their own banner style; anything else is a draw.
    announce, message = {
        "1-0": (st.success, "White wins!"),
        "0-1": (st.error, "Black wins!"),
    }.get(result, (st.info, "Draw!"))
    announce(message)

    st.write(f"**Result:** {result}")
    st.write(f"**Termination:** {outcome.termination.name}")

    if st.button("Start New Game"):
        st.session_state.board = chess.Board()
        st.session_state.history = []
        st.rerun()