Spaces:
Sleeping
Sleeping
File size: 11,003 Bytes
e7b67e8 4577350 f0901a9 a34e340 f0901a9 a34e340 6661b61 0248b5f 6f1ce51 f162225 b285220 f162225 2868796 6661b61 b285220 f162225 e386b0b 6661b61 2868796 6661b61 2868796 6661b61 f162225 4577350 f162225 6f1ce51 f162225 6f1ce51 f162225 6f1ce51 0248b5f f0901a9 a34e340 4560539 f162225 6f1ce51 f162225 6f1ce51 f162225 6f1ce51 4560539 a34e340 f162225 f0901a9 f162225 6f1ce51 0248b5f 6f1ce51 4577350 e386b0b 6f1ce51 e386b0b f0901a9 6f1ce51 a34e340 6f1ce51 f162225 6f1ce51 f162225 6f1ce51 f162225 6f1ce51 f162225 6f1ce51 1ba4af6 f162225 a34e340 f162225 6f1ce51 f162225 6f1ce51 f162225 6f1ce51 f162225 6f1ce51 f162225 6f1ce51 f162225 6f1ce51 f162225 4577350 6f1ce51 a34e340 6f1ce51 f162225 6f1ce51 f162225 6f1ce51 4560539 6f1ce51 f0901a9 6f1ce51 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 |
import streamlit as st
import chess
import chess.svg
import chess.pgn
import base64
import torch
import requests
from o1.agent import Agent
import random
import re
# Direct download link for the o1 checkpoint on the Hugging Face Hub.
HF_MODEL_URL = "https://huggingface.co/FlameF0X/o1/resolve/main/o1_agent.pth"

# Browser-tab title and a centered page layout.
st.set_page_config(
    page_title="Play Chess vs o1",
    layout="centered",
)
@st.cache_resource
def load_o1():
    """Download the o1 checkpoint from Hugging Face and return a ready Agent.

    Decorated with ``st.cache_resource`` so the returned agent is reused
    across reruns instead of being rebuilt each time.

    Returns:
        Agent: an o1 agent whose model has the downloaded weights loaded and
        has been switched to eval mode.

    Raises:
        RuntimeError: if the download does not answer with HTTP 200.
        requests.RequestException: on network failure or timeout.
    """
    import io

    # A bounded timeout keeps a stalled connection from hanging the app.
    response = requests.get(HF_MODEL_URL, timeout=60)
    if response.status_code != 200:
        raise RuntimeError("Failed to download model from Hugging Face.")

    # Load straight from memory. The previous NamedTemporaryFile(delete=False)
    # was never removed, leaving an orphaned copy of the checkpoint on disk.
    # NOTE(review): torch.load unpickles its input and this file comes from the
    # network; consider weights_only=True if the installed torch supports it.
    state_dict = torch.load(io.BytesIO(response.content), map_location="cpu")

    o1 = Agent()
    o1.model.load_state_dict(state_dict)
    o1.model.eval()
    return o1
def parse_move_input(move_input, board):
    """Parse a user-typed move in any supported notation into a legal move.

    Attempts, in order:
      1. UCI (``e2e4``, ``e7e8q``), case-insensitive, when the 3rd and 4th
         characters are alphanumeric.
      2. SAN exactly as typed, then upper-case, lower-case and capitalized
         variants, so ``e4``, ``E4``, ``nf3`` and ``NF3`` are all accepted.
      3. Common castling spellings (``0-0``, ``oo``, ``o-o-o``, ...).

    Args:
        move_input: Raw text from the input box; may be empty or None.
        board: Position in which the move must be legal.

    Returns:
        chess.Move | None: the first legal move matched, or None if nothing matched.
    """
    if not move_input:
        return None
    move_input = move_input.strip()
    if not move_input:
        # Whitespace-only input can never be a move.
        return None

    # 1. UCI notation.
    if len(move_input) >= 4 and move_input[2:4].isalnum():
        try:
            move = chess.Move.from_uci(move_input.lower())
        except ValueError:  # python-chess's InvalidMoveError subclasses ValueError
            move = None
        if move is not None and move in board.legal_moves:
            return move

    # 2. + 3. SAN, with case variants and castling aliases. dict.fromkeys
    # drops duplicate spellings while keeping the attempt order.
    castling_aliases = {
        "0-0": "O-O",
        "0-0-0": "O-O-O",
        "oo": "O-O",
        "ooo": "O-O-O",
        "o-o": "O-O",
        "o-o-o": "O-O-O",
    }
    lowered = move_input.lower()
    candidates = [move_input, move_input.upper(), lowered, move_input.capitalize()]
    castling = castling_aliases.get(lowered)
    if castling is not None:
        candidates.append(castling)

    for san in dict.fromkeys(candidates):
        try:
            move = board.parse_san(san)
        except ValueError:  # Invalid/Illegal/AmbiguousMoveError all subclass ValueError
            continue
        if move in board.legal_moves:
            return move
    return None
def _heuristic_move_score(board, move):
    """Score a move for the fallback policy: captures, then central squares, plus a random tie-break."""
    score = 10 if board.is_capture(move) else 0
    file = chess.square_file(move.to_square)
    rank = chess.square_rank(move.to_square)
    # Files c-f and ranks 3-6 (0-based indices 2..5).
    if 2 <= file <= 5 and 2 <= rank <= 5:
        score += 5
    return score + random.random()


def get_o1_move(o1, board):
    """Choose o1's move for the side to move on ``board``.

    Strategy:
      1. If ``o1.predict(board)`` returns logits whose second dimension is at
         least the number of legal moves, score the i-th legal move with
         ``policy_logits[0, i]`` and play the highest-scoring one.
      2. Otherwise use a heuristic: +10 for captures, +5 for landing on the
         central 4x4 region, plus a random value in [0, 1).
      3. If anything raises, show the error in the UI and play a random
         legal move.

    Args:
        o1: Agent exposing ``predict(board) -> (policy_logits, value)``.
        board: Current ``chess.Board``.

    Returns:
        chess.Move | None: the chosen move, or None if there are no legal moves.
    """
    legal_moves = list(board.legal_moves)
    if not legal_moves:
        # Nothing to choose from, so skip the model call entirely.
        return None

    try:
        policy_logits, _ = o1.predict(board)

        shape = getattr(policy_logits, "shape", None)
        if shape is not None and len(shape) > 1 and shape[1] >= len(legal_moves):
            # NOTE(review): assumes policy index i corresponds to the i-th entry
            # of board.legal_moves; that mapping is not established here.
            # Confirm against o1.agent's move encoding.
            # max() with a key compares scores only. Sorting (score, Move) tuples
            # compared chess.Move objects on ties, which raised TypeError.
            best_index = max(
                range(len(legal_moves)),
                key=lambda i: policy_logits[0, i].item(),
            )
            return legal_moves[best_index]

        return max(legal_moves, key=lambda move: _heuristic_move_score(board, move))
    except Exception as e:
        st.error(f"Error in o1 prediction: {e}")
        # Fallback: any legal move keeps the game going.
        return random.choice(legal_moves)
# Per-session game state, kept in st.session_state so it survives reruns.
for _state_key, _make_default in (("board", chess.Board), ("history", list)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

# Load o1; on failure the app keeps working with random AI moves.
o1, o1_loaded = None, False
try:
    o1 = load_o1()
except Exception as e:
    st.error(f"Failed to load o1: {e}")
else:
    o1_loaded = True

board = st.session_state.board
history = st.session_state.history

st.title("♟️ Play Chess vs o1")
if not o1_loaded:
    st.warning("o1 failed to load. The app will use random moves for the AI.")

if st.button("Reset Game"):
    # Fresh board and empty move list, then restart the script.
    st.session_state.board = chess.Board()
    st.session_state.history = []
    st.rerun()

# Slot that render_board() fills with the SVG board.
board_placeholder = st.empty()
def render_board():
    """Draw the current position as a centered SVG into ``board_placeholder``.

    The most recent move, if any, is highlighted. Rendering errors are shown
    in the UI instead of being raised.
    """
    try:
        if board.move_stack:
            highlighted = board.peek()
        else:
            highlighted = None
        svg_markup = chess.svg.board(board=board, lastmove=highlighted, size=400)
        board_placeholder.markdown(
            f'<div style="display: flex; justify-content: center;">{svg_markup}</div>',
            unsafe_allow_html=True,
        )
    except Exception as e:
        st.error(f"Error rendering board: {e}")
render_board()

# Status row: side to move on the left, check indicator on the right.
turn_col, check_col = st.columns(2)
side_to_move = "White" if board.turn == chess.WHITE else "Black"
with turn_col:
    st.write(f"**Turn:** {side_to_move}")
with check_col:
    if board.is_check():
        st.write("**Check!**")
# Player move input (White): list legal moves, accept typed moves, or play a random one.
if not board.is_game_over() and board.turn == chess.WHITE:
    st.write("### Your Turn (White)")
    # Reference lists in both notations. Every move here is legal in the
    # current position, so board.san() cannot fail.
    legal_moves = list(board.legal_moves)
    legal_moves_uci = [move.uci() for move in legal_moves]
    legal_moves_san = [board.san(move) for move in legal_moves]
    with st.expander("Show legal moves"):
        col1, col2 = st.columns(2)
        with col1:
            st.write("**Algebraic notation:**")
            st.write(", ".join(sorted(legal_moves_san)))
        with col2:
            st.write("**UCI notation:**")
            st.write(", ".join(sorted(legal_moves_uci)))

    user_move = st.text_input(
        "Enter your move (e.g., E4, Nf3, e2e4, O-O):",
        key="move_input",
        help="You can use algebraic notation (E4, Nf3) or UCI notation (e2e4). Case doesn't matter!",
    )

    col1, col2 = st.columns(2)
    with col1:
        if st.button("Submit Move"):
            if not user_move:
                st.warning("Please enter a move.")
            else:
                parsed_move = parse_move_input(user_move, board)
                if not parsed_move:
                    st.warning(f"Invalid move: '{user_move}'. Please check the legal moves above.")
                else:
                    try:
                        # SAN must be computed BEFORE push: board.san() needs the
                        # move to be legal in the current position. Calling it
                        # after push made a move that was actually played get
                        # reported as "Error making move".
                        move_san = board.san(parsed_move)
                        board.push(parsed_move)
                    except Exception as e:
                        st.error(f"Error making move: {e}")
                    else:
                        history.append(parsed_move.uci())
                        # NOTE(review): this message is followed immediately by
                        # a rerun, so it may not stay visible.
                        st.success(f"You played: {move_san} ({parsed_move.uci()})")
                        # Outside the try so the broad handler can never
                        # intercept Streamlit's rerun.
                        st.rerun()
    with col2:
        if st.button("Random Move"):
            if legal_moves:
                random_move = random.choice(legal_moves)
                board.push(random_move)
                history.append(random_move.uci())
                st.rerun()
# o1 move (Black): ask the agent, or a random legal move if it failed to load.
if not board.is_game_over() and board.turn == chess.BLACK:
    st.write("### o1's Turn (Black)")
    o1_moved = False
    with st.spinner("o1 is thinking..."):
        try:
            if o1_loaded and o1:
                best_move = get_o1_move(o1, board)
            else:
                # Fallback to a random move if o1 is not loaded.
                legal_moves = list(board.legal_moves)
                best_move = random.choice(legal_moves) if legal_moves else None

            if best_move and best_move in board.legal_moves:
                # SAN is computed before push, while the move is still legal.
                move_san = board.san(best_move)
                board.push(best_move)
                history.append(best_move.uci())
                st.success(f"o1 played: {move_san} ({best_move.uci()})")
                o1_moved = True
            else:
                st.error("o1 couldn't find a valid move")
        except Exception as e:
            st.error(f"Error during o1 move: {e}")
    # st.rerun() halts the current run. Calling it outside the broad
    # `except Exception` ensures the handler can never intercept it.
    if o1_moved:
        st.rerun()
# Game history: replay the recorded UCI moves into a PGN with the current result.
if history:
    st.write("### Game History")
    try:
        game = chess.pgn.Game()
        game.headers["Event"] = "Human vs o1"
        game.headers["White"] = "Human"
        game.headers["Black"] = "o1"
        node = game
        replay = chess.Board()
        for uci in history:
            move = chess.Move.from_uci(uci)
            if move not in replay.legal_moves:
                st.warning(f"Invalid move in history: {uci}")
                break
            node = node.add_main_variation(move)
            replay.push(move)
        # Without this, the PGN always ended in "*" even for finished games.
        # result() gives "1-0", "0-1" or "1/2-1/2" once the game is over,
        # and "*" otherwise.
        game.headers["Result"] = replay.result()
        st.code(str(game), language="pgn")
    except Exception as e:
        st.error(f"Error generating PGN: {e}")
        # Fallback: numbered move pairs in raw UCI (White moves first).
        move_pairs = []
        for i in range(0, len(history), 2):
            white_move = history[i]
            black_move = history[i + 1] if i + 1 < len(history) else ""
            move_pairs.append(f"{i//2 + 1}. {white_move} {black_move}")
        st.code("\n".join(move_pairs))
# Game over: announce the result and termination, and offer a new game.
if board.is_game_over():
    st.write("### Game Over!")
    result = board.result()
    outcome = board.outcome()
    # Banner style and text per result string; anything else is a draw.
    result_banners = {
        "1-0": (st.success, "White wins!"),
        "0-1": (st.error, "Black wins!"),
    }
    show_banner, banner_text = result_banners.get(result, (st.info, "Draw!"))
    show_banner(banner_text)
    st.write(f"**Result:** {result}")
    st.write(f"**Termination:** {outcome.termination.name}")
    if st.button("Start New Game"):
        st.session_state.board = chess.Board()
        st.session_state.history = []
        st.rerun()