Spaces:
Running
Running
import os | |
# Fix for permission denied issue in Hugging Face Spaces | |
HF_HOME = "/tmp/huggingface" | |
os.environ["HF_HOME"] = HF_HOME | |
os.makedirs(HF_HOME, exist_ok=True) | |
import streamlit as st | |
st.set_page_config(page_title="Play Chess vs o2", layout="centered") | |
import chess | |
import chess.svg | |
import torch | |
from o2_model import O2Net, board_to_tensor | |
from o2_agent import O2Agent | |
from PIL import Image | |
import io | |
import base64 | |
import chess.pgn | |
import random | |
from huggingface_hub import hf_hub_download | |
# Hugging Face model config
MODEL_REPO = "FlameF0X/o2"
MODEL_FILENAME = "o2_agent_latest.pth"
MODEL_CACHE_DIR = HF_HOME  # same writable cache root that HF_HOME points at


def ensure_model():
    """Fetch the o2 checkpoint via ``hf_hub_download`` and return its local path.

    Returns:
        The local file path reported by ``hf_hub_download``, or ``None`` if the
        call raised (the error is shown in the UI with ``st.error``).
    """
    download_kwargs = dict(
        repo_id=MODEL_REPO,
        filename=MODEL_FILENAME,
        cache_dir=MODEL_CACHE_DIR,
    )
    try:
        return hf_hub_download(**download_kwargs)
    except Exception as exc:
        st.error(f"Error downloading model: {exc}")
    return None
def load_agent():
    """Create an ``O2Agent`` and load the downloaded o2 weights into its model.

    All failures are reported in the UI with ``st.error`` instead of being raised.

    Returns:
        The agent with ``agent.model`` switched to eval mode, or ``None`` if the
        checkpoint could not be obtained, the agent could not be built, or the
        weights could not be loaded.
    """
    from contextlib import suppress

    # ensure_model() reports its own errors and returns None on failure.
    model_path = ensure_model()
    if model_path is None or not os.path.isfile(model_path):
        st.error(f"Model file not found at {model_path}")
        return None

    try:
        agent = O2Agent()
    except Exception as e:
        # Not the checkpoint's fault, so leave the downloaded file alone.
        st.error(f"Error loading model: {e}")
        return None

    try:
        # NOTE(review): torch.load unpickles the file, so it is only as trustworthy
        # as the MODEL_REPO it was downloaded from. Consider weights_only=True if the
        # installed torch version supports it.
        state_dict = torch.load(model_path, map_location=torch.device("cpu"))
        agent.model.load_state_dict(state_dict)
        agent.model.eval()
    except Exception as e:
        st.error(f"Error loading model: {e}")
        # The checkpoint may be truncated or corrupt. Remove it so a later attempt
        # can fetch it again. This is best effort: a failed delete must not hide
        # the original load error.
        # NOTE(review): hf_hub_download may return a symlink into its blob cache;
        # confirm that removing the link actually forces a fresh download.
        with suppress(OSError):
            os.remove(model_path)
        return None

    return agent
def render_svg(svg):
    """Embed SVG markup in an ``<img>`` tag through a base64 ``data:`` URI.

    Args:
        svg: SVG document as a string.

    Returns:
        A ``(html, b64)`` tuple: the ``<img>`` tag and the bare base64 payload.
    """
    raw_bytes = svg.encode("utf-8")
    b64_text = base64.b64encode(raw_bytes).decode("ascii")
    img_tag = f"<img src='data:image/svg+xml;base64,{b64_text}'/>"
    return img_tag, b64_text
def parse_move_input(move_input, board):
    """Turn user-typed text into a legal ``chess.Move`` for ``board``.

    The formats are tried in this order:

    1. UCI, e.g. ``e2e4`` or ``e7e8q`` (case-insensitive).
    2. SAN, e.g. ``e4``, ``Nf3``, ``O-O``.
    3. Common castling spellings: ``0-0``, ``oo``, ``o-o`` and their queenside
       forms, optionally followed by ``+`` or ``#``.

    Args:
        move_input: Raw text from the user. May be ``None`` or empty.
        board: Position in which the move must be legal.

    Returns:
        The matching legal move, or ``None`` if no interpretation gives one.
    """
    if not move_input:
        return None
    text = move_input.strip()
    if not text:
        # Whitespace-only input: nothing to parse.
        return None

    # UCI: 4 characters (e2e4), or 5 with a promotion piece (e7e8q).
    if len(text) in (4, 5):
        try:
            move = chess.Move.from_uci(text.lower())
        except ValueError:
            move = None
        if move is not None and move in board.legal_moves:
            return move

    # SAN. python-chess signals invalid, illegal or ambiguous SAN with ValueError
    # subclasses. Catching only those keeps unrelated errors visible.
    try:
        move = board.parse_san(text)
    except ValueError:
        pass
    else:
        if move in board.legal_moves:
            return move

    # Castling aliases. A trailing check/mate marker is ignored for the lookup.
    castling_variations = {
        '0-0': 'O-O', '0-0-0': 'O-O-O',
        'oo': 'O-O', 'ooo': 'O-O-O',
        'o-o': 'O-O', 'o-o-o': 'O-O-O',
    }
    normalized = castling_variations.get(text.lower().rstrip("+#"))
    if normalized:
        try:
            move = board.parse_san(normalized)
        except ValueError:
            return None
        if move in board.legal_moves:
            return move
    return None
# --- Session state ---
# Streamlit re-executes the script on every interaction. The game therefore lives
# in session state and is created only on the first run of a session.
for _state_key, _make_default in (("board", chess.Board), ("history", list)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# --- Agent ---
# Streamlit re-executes this whole script on every interaction (every button click
# ends in st.rerun()). Calling load_agent() unconditionally would repeat the Hub
# download check and torch.load on each click. Cache a successfully loaded agent
# for the session. A failed load is not cached, so it is retried on the next rerun,
# as before.
agent = st.session_state.get("agent")
if agent is None:
    with st.spinner("Loading o2 model..."):
        agent = load_agent()
    if agent is not None:
        st.session_state.agent = agent
agent_loaded = agent is not None
if agent_loaded:
    st.success("o2 model loaded successfully!")
else:
    st.warning("Model not loaded. Using random AI.")
# Short module-level names for the persisted game objects. They are the same
# objects held in session state, so mutating them updates the session directly.
board = st.session_state.board
history = st.session_state.history

st.title("♟️ Play Chess vs o2")
if not agent_loaded:
    st.info("🎲 AI fallback: random move selection.")

if st.button("Reset Game"):
    # Fresh starting position and an empty move list, then redraw from scratch.
    st.session_state.board, st.session_state.history = chess.Board(), []
    st.rerun()
col_board, col_pgn = st.columns([2, 1])

with col_board:
    board_placeholder = st.empty()

    def render_board():
        """Draw the current position into the board placeholder.

        The last move, if any, is highlighted. Rendering errors are shown with
        ``st.error`` instead of being raised.
        """
        try:
            if board.move_stack:
                highlight = board.peek()
            else:
                highlight = None
            svg_markup = chess.svg.board(board=board, lastmove=highlight, size=400)
            centered_html = f'<div style="display: flex; justify-content: center;">{svg_markup}</div>'
            board_placeholder.markdown(centered_html, unsafe_allow_html=True)
        except Exception as exc:
            st.error(f"Render error: {exc}")

    render_board()
with col_pgn:
    st.write("### Game History")
    pgn_placeholder = st.empty()

    def render_pgn():
        """Show the moves played so far as PGN in the history column.

        The PGN is rebuilt by replaying ``history`` (UCI strings) from the initial
        position. If that fails, a plain numbered list of UCI moves is shown instead.
        """
        if not history:
            pgn_placeholder.text("No moves yet")
            return
        try:
            game = chess.pgn.Game()
            game.headers["Event"] = "Human vs o2"
            game.headers["White"] = "Human"
            game.headers["Black"] = "o2" if agent_loaded else "Random AI"
            node = game
            replay = chess.Board()
            for uci in history:
                move = chess.Move.from_uci(uci)
                # Moves that are not legal in the replayed position are skipped.
                if move in replay.legal_moves:
                    node = node.add_main_variation(move)
                    replay.push(move)
            # A new Game's Result header is "*". Record the real result so a
            # finished game's PGN shows it ("*" is kept while still in progress).
            game.headers["Result"] = replay.result()
            pgn_placeholder.code(str(game), language="pgn")
        except Exception:
            # Best-effort fallback display. Catching Exception rather than using a
            # bare except lets BaseException subclasses (KeyboardInterrupt,
            # SystemExit, ...) propagate.
            lines = [
                f"{number}. " + " ".join(history[start:start + 2])
                for number, start in enumerate(range(0, len(history), 2), start=1)
            ]
            pgn_placeholder.code("\n".join(lines))

    render_pgn()
# --- Human move (White) ---
def _apply_human_move(move):
    """Play ``move`` for White, record it in ``history``, and restart the run.

    Nothing is redrawn here: st.rerun() ends this run, and the fresh run renders
    the board and PGN from the updated state.
    """
    board.push(move)
    history.append(move.uci())
    st.rerun()


if not board.is_game_over() and board.turn == chess.WHITE:
    st.write("### Your Turn (White)")
    legal_moves = list(board.legal_moves)
    legal_moves_uci = [m.uci() for m in legal_moves]
    legal_moves_san = [board.san(m) for m in legal_moves]
    with st.expander("Show legal moves"):
        col1, col2 = st.columns(2)
        with col1:
            st.write("**SAN:**", ", ".join(legal_moves_san))
        with col2:
            st.write("**UCI:**", ", ".join(legal_moves_uci))

    user_move = st.text_input("Enter your move:", key="move_input")
    if st.button("Submit Move"):
        move = parse_move_input(user_move, board)
        if move is not None:
            _apply_human_move(move)
        else:
            st.warning("Invalid move. Use standard algebraic notation (e.g., `e4`, `Nf3`, `O-O`) or UCI (e.g., `e2e4`).")
    if st.button("Random Move"):
        _apply_human_move(random.choice(legal_moves))
# --- o2's move (Black) ---
# o2's latest reply is stored in session state as (ply, san, problem). The
# st.rerun() that follows the move discards anything rendered during that run, so
# the message is shown here on the next run instead. It stays visible while the
# history length still equals that ply, i.e. until the human moves or the game is
# reset.
_last_ai_move = st.session_state.get("last_ai_move")
if _last_ai_move is not None and _last_ai_move[0] == len(history):
    _, _last_ai_san, _last_ai_problem = _last_ai_move
    if _last_ai_problem:
        st.error(_last_ai_problem)
    st.success(f"o2 played: {_last_ai_san}")

if not board.is_game_over() and board.turn == chess.BLACK:
    st.write("### o2's Turn (Black)")
    with st.spinner("o2 is thinking..."):
        move = None
        problem = None
        if agent_loaded:
            try:
                candidate = agent.select_move(
                    board,
                    use_mcts=True,
                    simulations=30,
                    # Temperature 1.2 for the first 20 plies, 0.0 afterwards
                    # (presumably varied openings, then the strongest move).
                    temperature=1.2 if len(history) < 20 else 0.0,
                )
                # Never push an unchecked move: an illegal one would corrupt the board.
                if candidate is not None and candidate in board.legal_moves:
                    move = candidate
                else:
                    problem = f"AI returned no legal move ({candidate}); a random move was played instead."
            except Exception as e:
                problem = f"AI error: {e}"
        if move is None:
            # No model, a model failure, or no usable move. Keep the game going with
            # a random legal move, matching the file's random-AI fallback.
            move = random.choice(list(board.legal_moves))
        san_move = board.san(move)  # SAN must be computed before the move is pushed
        board.push(move)
        history.append(move.uci())
        st.session_state.last_ai_move = (len(history), san_move, problem)
    st.rerun()
# --- Game over ---
if board.is_game_over():
    result = board.result()
    outcome = board.outcome()
    st.write("### Game Over!")
    # Pick the banner style and text from the result string; anything other than a
    # decisive result is reported as a draw.
    _banner, _banner_text = {
        "1-0": (st.success, "You win!"),
        "0-1": (st.error, "o2 wins!"),
    }.get(result, (st.info, "Draw!"))
    _banner(_banner_text)
    st.write(f"**Result:** {result}")
    st.write(f"**Termination:** {outcome.termination.name}")
    if st.button("New Game"):
        st.session_state.board = chess.Board()
        st.session_state.history = []
        st.rerun()