play-with-o1 / src /streamlit_app.py
FlameF0X's picture
Update src/streamlit_app.py
f162225 verified
import streamlit as st
import chess
import chess.svg
import chess.pgn
import base64
import torch
import requests
from o1.agent import Agent
import random
import re
# Hugging Face location of the pretrained o1 weights (fetched by load_o1).
HF_MODEL_URL = "https://huggingface.co/FlameF0X/o1/resolve/main/o1_agent.pth"
# Page-level configuration, issued before any other st.* call in this script.
st.set_page_config(layout="centered", page_title="Play Chess vs o1")
@st.cache_resource
def load_o1():
    """Download the o1 weights from Hugging Face and return a ready Agent.

    The checkpoint at ``HF_MODEL_URL`` is loaded directly from memory, so no
    temporary file is written to (or leaked on) disk. The returned agent is
    cached by ``st.cache_resource`` so reruns do not repeat the download.

    Returns:
        An ``o1.agent.Agent`` whose model holds the downloaded state dict and
        has been switched to eval mode.

    Raises:
        RuntimeError: if the download does not answer with HTTP 200.
        requests.RequestException: on network failure or timeout.
    """
    import io

    # A timeout keeps a stalled connection from hanging the app forever.
    response = requests.get(HF_MODEL_URL, timeout=60)
    if response.status_code != 200:
        raise RuntimeError("Failed to download model from Hugging Face.")
    o1 = Agent()
    # NOTE(review): torch.load unpickles the downloaded bytes, which can run
    # arbitrary code if the remote file is tampered with. Consider
    # weights_only=True once the deployed torch version is confirmed to
    # support it for this checkpoint.
    state_dict = torch.load(io.BytesIO(response.content), map_location='cpu')
    o1.model.load_state_dict(state_dict)
    o1.model.eval()
    return o1
def parse_move_input(move_input, board):
    """Parse various move input formats and return a chess.Move object.

    Candidates are tried in this order. The first one that both parses and is
    legal on ``board`` is returned:

    1. UCI coordinates, case-insensitive (``e2e4``, ``E7E8Q``).
    2. SAN exactly as typed (``Nf3``, ``exd5``, ``O-O``).
    3. SAN with the case normalised: upper, lower, then capitalised. So ``E4``
       resolves via ``e4``, and ``NF3`` via ``Nf3``.
    4. Castling aliases ``0-0``/``oo``/``o-o`` and the matching long-castle
       forms, compared case-insensitively.

    Args:
        move_input: raw text typed by the user; ``None`` or blank is allowed.
        board: the ``chess.Board`` the move would be played on.

    Returns:
        The matching legal ``chess.Move``, or ``None`` if no candidate works.
    """
    if not move_input:
        return None
    move_input = move_input.strip()
    if not move_input:
        # Whitespace-only input: nothing to parse.
        return None

    # python-chess reports unparsable, illegal and ambiguous moves with
    # ValueError (or subclasses of it), so that is the only exception handled.
    def _legal_san(text):
        """Return the legal move spelled by SAN ``text`` on ``board``, else None."""
        try:
            move = board.parse_san(text)
        except ValueError:
            return None
        return move if move in board.legal_moves else None

    # 1. UCI format (e2e4, e7e8q, ...). The [2:4] test is a cheap pre-filter.
    if len(move_input) >= 4 and move_input[2:4].isalnum():
        try:
            move = chess.Move.from_uci(move_input.lower())
        except ValueError:
            move = None
        if move is not None and move in board.legal_moves:
            return move

    # 2-4. SAN spellings, then castling aliases.
    castling_variations = {
        '0-0': 'O-O',
        '0-0-0': 'O-O-O',
        'oo': 'O-O',
        'ooo': 'O-O-O',
        'o-o': 'O-O',
        'o-o-o': 'O-O-O',
    }
    candidates = [
        move_input,
        move_input.upper(),
        move_input.lower(),
        move_input.capitalize(),
    ]
    alias = castling_variations.get(move_input.lower())
    if alias is not None:
        candidates.append(alias)

    # dict.fromkeys removes repeated spellings (e.g. "e4".lower() == "e4")
    # while preserving the try order above.
    for candidate in dict.fromkeys(candidates):
        move = _legal_san(candidate)
        if move is not None:
            return move
    return None
def get_o1_move(o1, board):
    """Get the best move from o1 with proper error handling.

    Strategy:

    1. Call ``o1.predict(board)``, which is expected to return
       ``(policy_logits, value)``. If the logits have at least 2 dimensions
       and at least as many columns as there are legal moves, legal move ``i``
       is scored with ``policy_logits[0, i]``. The highest score is played;
       ties go to the earliest move.
    2. Otherwise use a heuristic: +10 for a capture, +5 for landing on the
       c3-f6 block, plus a random jitter in [0, 1).
    3. If anything above raises, report it with ``st.error`` and play a
       random legal move.

    NOTE(review): step 1 assumes policy column ``i`` corresponds to the
    ``i``-th entry of ``board.legal_moves``. Whether o1 encodes moves that way
    is not visible here; TODO confirm against ``o1.agent``.

    Returns:
        A legal ``chess.Move``, or ``None`` when there are no legal moves.
    """
    try:
        policy_logits, _ = o1.predict(board)
        legal_moves = list(board.legal_moves)
        if not legal_moves:
            return None

        # Method 1: policy logits, when there is a column for every legal move.
        shape = getattr(policy_logits, 'shape', None)
        if shape is not None and len(shape) > 1 and shape[1] >= len(legal_moves):
            scores = [policy_logits[0, i].item() for i in range(len(legal_moves))]
            # max() picks the first best index on a tie. Sorting
            # (score, move) tuples would instead compare chess.Move objects
            # on a tie, and Move defines no ordering (TypeError).
            best_index = max(range(len(scores)), key=scores.__getitem__)
            return legal_moves[best_index]

        # Method 2: simple heuristic. Prefer captures, then the c3-f6 block,
        # with random jitter to break ties and vary play.
        def _heuristic_score(move):
            score = 0
            if board.is_capture(move):
                score += 10
            to_file = chess.square_file(move.to_square)
            to_rank = chess.square_rank(move.to_square)
            if 2 <= to_file <= 5 and 2 <= to_rank <= 5:
                score += 5
            return score + random.random()

        return max(legal_moves, key=_heuristic_score)
    except Exception as e:
        st.error(f"Error in o1 prediction: {e}")
        # Fallback to random legal move
        legal_moves = list(board.legal_moves)
        return random.choice(legal_moves) if legal_moves else None
# Seed per-session state on the first run: a fresh board and an empty
# list of played moves (UCI strings).
for state_key, make_default in (("board", chess.Board), ("history", list)):
    if state_key not in st.session_state:
        st.session_state[state_key] = make_default()

# Load o1; on failure report it and let the AI side fall back to random moves.
try:
    o1 = load_o1()
except Exception as e:
    st.error(f"Failed to load o1: {e}")
    o1 = None
o1_loaded = o1 is not None

board = st.session_state["board"]
history = st.session_state["history"]

st.title("♟️ Play Chess vs o1")
if not o1_loaded:
    st.warning("o1 failed to load. The app will use random moves for the AI.")

if st.button("Reset Game"):
    # Replace the game state, then restart the script against it.
    st.session_state["board"] = chess.Board()
    st.session_state["history"] = []
    st.rerun()

# Slot near the top of the page that render_board() draws into.
board_placeholder = st.empty()
def render_board():
    """Draw the module-level ``board`` as a centred SVG in ``board_placeholder``.

    The most recent move, if there is one, is highlighted. Any rendering
    failure is shown with ``st.error`` instead of aborting the script.
    """
    try:
        highlighted = None
        if board.move_stack:
            highlighted = board.peek()
        svg_board = chess.svg.board(board=board, lastmove=highlighted, size=400)
        board_placeholder.markdown(f'<div style="display: flex; justify-content: center;">{svg_board}</div>', unsafe_allow_html=True)
    except Exception as e:
        st.error(f"Error rendering board: {e}")
render_board()

# Status row: side to move on the left, a check notice on the right.
turn_col, check_col = st.columns(2)
with turn_col:
    st.write(f"**Turn:** {'White' if board.turn == chess.WHITE else 'Black'}")
with check_col:
    if board.is_check():
        st.write("**Check!**")
# Player move input (White)
if not board.is_game_over() and board.turn == chess.WHITE:
    st.write("### Your Turn (White)")

    # Show legal moves in both formats for reference
    legal_moves = list(board.legal_moves)
    legal_moves_uci = [move.uci() for move in legal_moves]
    legal_moves_san = []
    for move in legal_moves:
        try:
            legal_moves_san.append(board.san(move))
        except Exception:
            # Best effort: list the move in UCI if SAN generation fails.
            legal_moves_san.append(move.uci())
    with st.expander("Show legal moves"):
        col1, col2 = st.columns(2)
        with col1:
            st.write("**Algebraic notation:**")
            st.write(", ".join(sorted(legal_moves_san)))
        with col2:
            st.write("**UCI notation:**")
            st.write(", ".join(sorted(legal_moves_uci)))

    user_move = st.text_input("Enter your move (e.g., E4, Nf3, e2e4, O-O):", key="move_input",
                              help="You can use algebraic notation (E4, Nf3) or UCI notation (e2e4). Case doesn't matter!")
    col1, col2 = st.columns(2)
    with col1:
        if st.button("Submit Move"):
            if not user_move:
                st.warning("Please enter a move.")
            else:
                parsed_move = parse_move_input(user_move, board)
                if parsed_move is None:
                    st.warning(f"Invalid move: '{user_move}'. Please check the legal moves above.")
                else:
                    try:
                        # SAN must be computed BEFORE push. Afterwards the move
                        # is no longer legal on the board, and board.san() raises.
                        move_san = board.san(parsed_move)
                        board.push(parsed_move)
                        history.append(parsed_move.uci())
                        # NOTE(review): the rerun below redraws the page, so
                        # this message is likely replaced almost immediately.
                        st.success(f"You played: {move_san} ({parsed_move.uci()})")
                    except Exception as e:
                        st.error(f"Error making move: {e}")
                    else:
                        # Kept outside the broad except so the rerun's
                        # control-flow exception is never reported as an error.
                        st.rerun()
    with col2:
        if st.button("Random Move"):
            if legal_moves:
                random_move = random.choice(legal_moves)
                board.push(random_move)
                history.append(random_move.uci())
                st.rerun()
# o1 move (Black)
if not board.is_game_over() and board.turn == chess.BLACK:
    st.write("### o1's Turn (Black)")
    moved = False
    with st.spinner("o1 is thinking..."):
        try:
            if o1_loaded and o1:
                best_move = get_o1_move(o1, board)
            else:
                # Fallback to random move if o1 not loaded
                legal_moves = list(board.legal_moves)
                best_move = random.choice(legal_moves) if legal_moves else None
            if best_move is not None and best_move in board.legal_moves:
                # SAN is taken before push, while the move is still legal.
                move_san = board.san(best_move)
                board.push(best_move)
                history.append(best_move.uci())
                st.success(f"o1 played: {move_san} ({best_move.uci()})")
                moved = True
            else:
                st.error("o1 couldn't find a valid move")
        except Exception as e:
            st.error(f"Error during o1 move: {e}")
    # Rerun outside the broad except so the rerun's control-flow exception
    # can never be caught and reported as "Error during o1 move".
    if moved:
        st.rerun()
# Game history
if history:
    st.write("### Game History")
    # Build a PGN by replaying the recorded UCI moves on a fresh board.
    try:
        game = chess.pgn.Game()
        game.headers["Event"] = "Human vs o1"
        game.headers["White"] = "Human"
        game.headers["Black"] = "o1"
        node = game
        temp_board = chess.Board()
        for uci in history:
            move = chess.Move.from_uci(uci)
            if move not in temp_board.legal_moves:
                st.warning(f"Invalid move in history: {uci}")
                break
            node = node.add_main_variation(move)
            temp_board.push(move)
        # Record the actual outcome of the replayed position. Without this the
        # Result header stays at its default "*" (unfinished), even after
        # checkmate or a draw.
        game.headers["Result"] = temp_board.result()
        st.code(str(game), language="pgn")
    except Exception as e:
        st.error(f"Error generating PGN: {e}")
        # Fallback: numbered list of UCI moves, one full move per line.
        move_pairs = []
        for i in range(0, len(history), 2):
            white_move = history[i]
            black_move = history[i+1] if i+1 < len(history) else ""
            move_pairs.append(f"{i//2 + 1}. {white_move} {black_move}")
        st.code("\n".join(move_pairs))
# Game over: announce the winner (or draw), show how it ended, offer a restart.
if board.is_game_over():
    st.write("### Game Over!")
    result = board.result()
    outcome = board.outcome()
    # Decisive results get a coloured banner; every other result is a draw.
    announce, banner = {
        "1-0": (st.success, "White wins!"),
        "0-1": (st.error, "Black wins!"),
    }.get(result, (st.info, "Draw!"))
    announce(banner)
    st.write(f"**Result:** {result}")
    st.write(f"**Termination:** {outcome.termination.name}")
    if st.button("Start New Game"):
        st.session_state["board"] = chess.Board()
        st.session_state["history"] = []
        st.rerun()