# Play-with-o2 / src /streamlit_app.py
# FlameF0X's picture
# Update src/streamlit_app.py
# af8e4a4 verified
import os
# Fix for permission denied issue in Hugging Face Spaces: point the Hugging
# Face cache at a directory under /tmp and create it up front.
# NOTE(review): this runs before `huggingface_hub` is imported further down,
# presumably so the library sees the variable when it initialises -- keep the
# ordering unless that is confirmed to be unnecessary.
HF_HOME = "/tmp/huggingface"
os.environ["HF_HOME"] = HF_HOME
os.makedirs(HF_HOME, exist_ok=True)
import streamlit as st
# Page configuration is issued immediately after importing Streamlit, ahead of
# every other `st.*` call made by this script.
st.set_page_config(page_title="Play Chess vs o2", layout="centered")
import chess
import chess.svg
import torch
from o2_model import O2Net, board_to_tensor
from o2_agent import O2Agent
from PIL import Image
import io
import base64
import chess.pgn
import random
from huggingface_hub import hf_hub_download
# Hugging Face model config (all three values are passed to hf_hub_download
# by ensure_model).
MODEL_REPO = "FlameF0X/o2"  # Hub repository id
MODEL_FILENAME = "o2_agent_latest.pth"  # checkpoint file inside that repository
MODEL_CACHE_DIR = HF_HOME  # downloads are cached in the HF_HOME directory set above
def ensure_model():
    """Fetch the o2 checkpoint from the Hugging Face Hub.

    Returns:
        The local file path reported by ``hf_hub_download``, or ``None`` when
        the download raised. In the failure case the exception text is shown
        to the user through ``st.error``.
    """
    try:
        local_checkpoint = hf_hub_download(
            cache_dir=MODEL_CACHE_DIR,
            filename=MODEL_FILENAME,
            repo_id=MODEL_REPO,
        )
    except Exception as download_error:
        # Any failure is reported in the UI and mapped to None for the caller.
        st.error(f"Error downloading model: {download_error}")
        return None
    else:
        return local_checkpoint
@st.cache_resource
def load_agent():
    """Create the o2 agent and load its weights from the downloaded checkpoint.

    Returns:
        An ``O2Agent`` whose model has the checkpoint's state dict loaded on
        CPU and has been switched to eval mode, or ``None`` when the
        checkpoint could not be obtained or loaded. Failures are reported with
        ``st.error``.

    NOTE(review): the function is wrapped in ``st.cache_resource``, so a
    ``None`` result looks like it would be cached as well -- confirm whether a
    failed load should be retried on later runs.
    """
    # Bound before the try block so the cleanup in the handler below can never
    # run into an unbound name.
    model_path = None
    try:
        model_path = ensure_model()
        if model_path is None or not os.path.isfile(model_path):
            st.error(f"Model file not found at {model_path}")
            return None
        agent = O2Agent()
        # NOTE(review): torch.load unpickles the file it is given. The
        # checkpoint is downloaded from the Hub, so consider passing
        # weights_only=True if the installed torch version supports it.
        state_dict = torch.load(model_path, map_location=torch.device("cpu"))
        agent.model.load_state_dict(state_dict)
        agent.model.eval()
        return agent
    except Exception as e:
        st.error(f"Error loading model: {e}")
        # Best-effort removal of the file that failed to load (presumably so a
        # bad download is not reused). Only filesystem errors are ignored;
        # anything else is allowed to surface.
        if model_path is not None:
            try:
                os.remove(model_path)
            except OSError:
                pass
        return None
def render_svg(svg):
    """Wrap an SVG document in an ``<img>`` tag that uses a base64 data URI.

    Args:
        svg: SVG markup as a string; it is encoded as UTF-8 before base64.

    Returns:
        A ``(html, b64)`` tuple: the ``<img>`` element and the bare base64
        payload it embeds.
    """
    raw_bytes = svg.encode('utf-8')
    b64 = base64.b64encode(raw_bytes).decode('utf-8')
    img_tag = f"<img src='data:image/svg+xml;base64,{b64}'/>"
    return img_tag, b64
def parse_move_input(move_input, board):
    """Turn the player's typed text into a legal move for ``board``.

    The text is tried, in order, as:

    1. UCI (only when it is 4 or 5 characters long; lower-cased first),
    2. SAN exactly as typed,
    3. a castling alias (``0-0``, ``oo``, ``o-o`` and their long forms,
       matched case-insensitively) rewritten to ``O-O`` / ``O-O-O``.

    Args:
        move_input: Raw text from the input box; may be empty or ``None``.
        board: The position the move must be legal in. Needs ``parse_san``
            and ``legal_moves``.

    Returns:
        The first interpretation that is legal on ``board``, or ``None`` when
        the text is empty, blank, or matches no legal move.
    """
    if not move_input:
        return None
    text = move_input.strip()
    if not text:
        # Whitespace-only input can never name a move.
        return None

    castling_variations = {
        '0-0': 'O-O', '0-0-0': 'O-O-O',
        'oo': 'O-O', 'ooo': 'O-O-O',
        'o-o': 'O-O', 'o-o-o': 'O-O-O'
    }

    # (parser, token) pairs, in the order they should be attempted.
    attempts = []
    if len(text) in (4, 5):
        attempts.append((chess.Move.from_uci, text.lower()))
    attempts.append((board.parse_san, text))
    normalized = castling_variations.get(text.lower())
    if normalized is not None:
        attempts.append((board.parse_san, normalized))

    for parser, token in attempts:
        try:
            move = parser(token)
        except ValueError:
            # python-chess reports unparsable, illegal and ambiguous input
            # with ValueError (or subclasses); just try the next notation.
            # Anything else is a real bug and is allowed to propagate.
            continue
        if move in board.legal_moves:
            return move
    return None
# --- Session state ---
# Seed the per-session game objects the first time the script runs.
for state_key, make_default in (("board", chess.Board), ("history", list)):
    if state_key not in st.session_state:
        st.session_state[state_key] = make_default()

# Load the agent (cached by load_agent) and tell the user which AI is active.
with st.spinner("Loading o2 model..."):
    agent = load_agent()
    agent_loaded = agent is not None
    if agent_loaded:
        st.success("o2 model loaded successfully!")
    else:
        st.warning("Model not loaded. Using random AI.")

# Short aliases for the session objects used by the rest of the script.
board = st.session_state.board
history = st.session_state.history

st.title("♟️ Play Chess vs o2")
if not agent_loaded:
    st.info("🎲 AI fallback: random move selection.")

# Reset replaces both session objects and restarts the script.
if st.button("Reset Game"):
    st.session_state.history = []
    st.session_state.board = chess.Board()
    st.rerun()
col_board, col_pgn = st.columns([2, 1])
with col_board:
    board_placeholder = st.empty()

    def render_board():
        """Draw the current position into ``board_placeholder``.

        The board is rendered as a 400px SVG with the last move (if any)
        highlighted and wrapped in a centering ``<div>``. Any exception is
        reported with ``st.error`` instead of being raised.
        """
        try:
            highlighted = None
            if board.move_stack:
                highlighted = board.peek()
            board_markup = chess.svg.board(board=board, lastmove=highlighted, size=400)
            centered_html = (
                '<div style="display: flex; justify-content: center;">'
                + f'{board_markup}'
                + '</div>'
            )
            board_placeholder.markdown(centered_html, unsafe_allow_html=True)
        except Exception as e:
            st.error(f"Render error: {e}")

    render_board()
with col_pgn:
    st.write("### Game History")
    pgn_placeholder = st.empty()

    def render_pgn():
        """Show the game so far in ``pgn_placeholder``.

        With no moves recorded a short notice is shown. Otherwise the UCI
        strings in ``history`` are replayed from the initial position into a
        ``chess.pgn.Game`` and its PGN text is displayed; entries that are not
        legal in the replayed position are skipped. If building or showing
        the PGN fails, a plain numbered list of the raw UCI moves is shown
        instead.
        """
        if not history:
            pgn_placeholder.text("No moves yet")
            return

        try:
            game = chess.pgn.Game()
            game.headers["Event"] = "Human vs o2"
            game.headers["White"] = "Human"
            game.headers["Black"] = "o2" if agent_loaded else "Random AI"
            node = game
            temp_board = chess.Board()
            for uci in history:
                move = chess.Move.from_uci(uci)
                if move in temp_board.legal_moves:
                    node = node.add_main_variation(move)
                    temp_board.push(move)
            pgn_placeholder.code(str(game), language="pgn")
        except Exception:
            # Deliberately broad: the fallback below is a best-effort display.
            # `Exception` rather than a bare `except` so that BaseException
            # signals (KeyboardInterrupt, SystemExit, ...) are not swallowed.
            fallback_rows = []
            for number, start in enumerate(range(0, len(history), 2), start=1):
                white_uci = history[start]
                black_uci = history[start + 1] if start + 1 < len(history) else ''
                fallback_rows.append(f"{number}. {white_uci} {black_uci}")
            pgn_placeholder.code("\n".join(fallback_rows))

    render_pgn()
if not board.is_game_over() and board.turn == chess.WHITE:
    st.write("### Your Turn (White)")

    # Legal moves in both notations, listed in the same order.
    legal_moves = list(board.legal_moves)
    legal_moves_uci = list(map(chess.Move.uci, legal_moves))
    legal_moves_san = list(map(board.san, legal_moves))

    with st.expander("Show legal moves"):
        san_column, uci_column = st.columns(2)
        with san_column:
            st.write("**SAN:**", ", ".join(legal_moves_san))
        with uci_column:
            st.write("**UCI:**", ", ".join(legal_moves_uci))

    def _play_white_move(chosen):
        """Push White's move, record it, redraw both panels and rerun."""
        board.push(chosen)
        history.append(chosen.uci())
        render_board()
        render_pgn()
        st.rerun()

    user_move = st.text_input("Enter your move:", key="move_input")
    if st.button("Submit Move"):
        move = parse_move_input(user_move, board)
        if not move:
            st.warning("Invalid move. Use standard algebraic notation (e.g., `e4`, `Nf3`, `O-O`) or UCI (e.g., `e2e4`).")
        else:
            _play_white_move(move)

    if st.button("Random Move"):
        _play_white_move(random.choice(legal_moves))
# Show o2's most recent reply. It is kept in session state because the
# st.rerun() issued right after the move restarts the script, so a message
# written during that same run is replaced before the player can read it.
# `history` is non-empty with White to move only after Black has replied in
# the current game, so a stale entry from an earlier game is never shown.
last_ai_reply = st.session_state.get("last_ai_reply")
if last_ai_reply and history and board.turn == chess.WHITE:
    reply_san, fallback_reason = last_ai_reply
    if fallback_reason:
        st.warning(f"o2 played: {reply_san} (random fallback: {fallback_reason})")
    else:
        st.success(f"o2 played: {reply_san}")

if not board.is_game_over() and board.turn == chess.BLACK:
    st.write("### o2's Turn (Black)")
    # Non-empty here because the game is not over.
    legal_replies = list(board.legal_moves)
    move = None
    fallback_reason = None
    with st.spinner("o2 is thinking..."):
        if agent_loaded:
            try:
                move = agent.select_move(
                    board,
                    use_mcts=True,
                    simulations=30,
                    # Sampling temperature 1.2 for the first 20 plies, 0.0 after.
                    temperature=1.2 if len(history) < 20 else 0.0,
                )
            except Exception as e:
                fallback_reason = f"AI error: {e}"
            else:
                if move is None or move not in legal_replies:
                    fallback_reason = f"o2 returned an unusable move ({move})"
        if move is None or move not in legal_replies:
            # Either no model is loaded (the normal random AI) or the agent
            # failed; in both cases keep the game moving with a random reply
            # instead of leaving it stuck on Black's turn.
            move = random.choice(legal_replies)

    san_move = board.san(move)  # SAN must be computed before the move is pushed
    board.push(move)
    history.append(move.uci())
    st.session_state.last_ai_reply = (san_move, fallback_reason)
    # Kept outside any try block: st.rerun() stops the script via an exception.
    st.rerun()
if board.is_game_over():
    result = board.result()
    outcome = board.outcome()

    st.write("### Game Over!")

    # Result string -> (Streamlit call, headline); anything else is a draw.
    headline_by_result = {
        "1-0": (st.success, "You win!"),
        "0-1": (st.error, "o2 wins!"),
    }
    announce, headline = headline_by_result.get(result, (st.info, "Draw!"))
    announce(headline)

    st.write(f"**Result:** {result}")
    st.write(f"**Termination:** {outcome.termination.name}")

    # Starting over replaces both session objects and restarts the script.
    if st.button("New Game"):
        st.session_state.board = chess.Board()
        st.session_state.history = []
        st.rerun()