File size: 11,003 Bytes
e7b67e8
4577350
f0901a9
a34e340
f0901a9
a34e340
6661b61
0248b5f
6f1ce51
f162225
b285220
f162225
2868796
 
6661b61
b285220
f162225
e386b0b
6661b61
2868796
6661b61
 
 
2868796
6661b61
f162225
 
 
 
4577350
f162225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f1ce51
f162225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f1ce51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f162225
6f1ce51
 
 
 
0248b5f
f0901a9
a34e340
 
4560539
f162225
6f1ce51
f162225
 
6f1ce51
f162225
 
 
6f1ce51
4560539
a34e340
 
f162225
f0901a9
f162225
 
6f1ce51
0248b5f
6f1ce51
 
 
4577350
e386b0b
 
 
6f1ce51
 
 
 
 
 
e386b0b
 
f0901a9
6f1ce51
 
 
 
 
 
 
 
 
a34e340
6f1ce51
 
f162225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f1ce51
f162225
 
6f1ce51
 
 
 
f162225
 
 
 
 
 
 
 
 
 
 
 
6f1ce51
f162225
6f1ce51
 
 
 
 
 
 
 
1ba4af6
f162225
a34e340
f162225
 
6f1ce51
f162225
 
6f1ce51
f162225
6f1ce51
 
 
 
f162225
6f1ce51
 
f162225
6f1ce51
 
f162225
6f1ce51
 
f162225
4577350
6f1ce51
a34e340
6f1ce51
 
 
 
 
f162225
6f1ce51
f162225
6f1ce51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4560539
6f1ce51
f0901a9
6f1ce51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
import streamlit as st
import chess
import chess.svg
import chess.pgn
import base64
import torch
import requests
from o1.agent import Agent
import random
import re

# Configure the browser tab title and a centered page layout.
st.set_page_config(page_title="Play Chess vs o1", layout="centered")

# Pretrained o1 weights; downloaded by load_o1() and applied with load_state_dict.
HF_MODEL_URL = "https://huggingface.co/FlameF0X/o1/resolve/main/o1_agent.pth"

@st.cache_resource
def load_o1():
    """Download the pretrained o1 weights and return a ready-to-use Agent.

    The checkpoint at ``HF_MODEL_URL`` is deserialized straight from memory,
    so nothing is written to disk. (Previously a ``delete=False`` temporary
    file was created on every cold start and never removed.) The function is
    wrapped in ``st.cache_resource`` so the loaded agent is reused across
    reruns.

    Returns:
        Agent: an ``o1.agent.Agent`` whose model holds the downloaded weights
        and has been switched to eval mode.

    Raises:
        RuntimeError: if the download does not return HTTP 200.
        requests.exceptions.RequestException: on network errors or timeout.
    """
    import io

    # Without a timeout a stalled connection would block the app indefinitely.
    response = requests.get(HF_MODEL_URL, timeout=60)
    if response.status_code != 200:
        raise RuntimeError("Failed to download model from Hugging Face.")

    # NOTE(review): torch.load unpickles its input and this file comes from the
    # network. Consider weights_only=True (torch >= 1.13) once the checkpoint
    # is confirmed to contain only tensors.
    state_dict = torch.load(io.BytesIO(response.content), map_location='cpu')

    o1 = Agent()
    o1.model.load_state_dict(state_dict)
    o1.model.eval()
    return o1

def parse_move_input(move_input, board):
    """Parse a user-typed move into a legal ``chess.Move``.

    Accepted forms:
      * UCI, any case: ``e2e4``, ``E7E8Q``
      * SAN in several letter cases: ``e4``, ``E4``, ``Nf3``, ``nf3``
      * Castling aliases: ``O-O``, ``0-0``, ``oo``, ``o-o-o``, ...

    Args:
        move_input: raw text from the move box; may be ``None`` or empty.
        board: the current ``chess.Board``; the move must be legal on it.

    Returns:
        The matching legal ``chess.Move``, or ``None`` if nothing matched.
    """
    if not move_input:
        return None

    move_input = move_input.strip()
    if not move_input:
        return None

    # 1) UCI. from_uci() only validates syntax, so legality is checked here.
    if len(move_input) >= 4 and move_input[2:4].isalnum():
        try:
            move = chess.Move.from_uci(move_input.lower())
        except ValueError:  # python-chess reports malformed UCI via ValueError
            move = None
        if move is not None and move in board.legal_moves:
            return move

    # 2) SAN in several letter-case spellings, then castling aliases.
    #    Order matters: the text as typed is tried first so that, e.g., "Bc4"
    #    (bishop) is preferred over the lower-cased "bc4" (b-file pawn).
    #    The lower-case spelling also covers bare pawn moves such as "E4".
    castling_variations = {
        '0-0': 'O-O',
        '0-0-0': 'O-O-O',
        'oo': 'O-O',
        'ooo': 'O-O-O',
        'o-o': 'O-O',
        'o-o-o': 'O-O-O',
    }
    lower_input = move_input.lower()
    candidates = [
        move_input,
        move_input.upper(),
        lower_input,
        move_input.capitalize(),
    ]
    if lower_input in castling_variations:
        candidates.append(castling_variations[lower_input])

    for candidate in dict.fromkeys(candidates):  # de-duplicate, keep order
        try:
            move = board.parse_san(candidate)
        except ValueError:
            # Invalid/illegal/ambiguous SAN is reported via ValueError
            # subclasses; any other exception is a real bug and propagates.
            continue
        # Explicit legality check kept as a cheap safety net.
        if move in board.legal_moves:
            return move

    return None

def get_o1_move(o1, board):
    """Choose o1's move for the side to move on ``board``.

    Strategy:
      1. If ``o1.predict(board)`` returns policy logits with a batch axis and
         at least one column per legal move, score legal move ``i`` with
         ``logits[0, i]`` and play the highest-scoring move.
      2. Otherwise use a cheap heuristic: +10 for captures, +5 for landing on
         the central c3-f6 block, plus random jitter.
      3. If anything raises, show the error in the UI and play a random legal
         move.

    Args:
        o1: the loaded agent; ``predict(board)`` must return a 2-tuple whose
            first element is the policy output.
        board: the current ``chess.Board``.

    Returns:
        A legal ``chess.Move``, or ``None`` if the position has no legal moves.
    """
    legal_moves = list(board.legal_moves)
    if not legal_moves:
        return None

    try:
        policy_logits, _ = o1.predict(board)

        shape = getattr(policy_logits, 'shape', None)
        if shape is not None and len(shape) > 1 and shape[1] >= len(legal_moves):
            # NOTE(review): assumes policy column i corresponds to the i-th move
            # of board.legal_moves. Nothing visible here establishes that
            # ordering; confirm it against o1's move encoding.
            # max() with a key (rather than sorting (score, move) tuples) never
            # compares chess.Move objects, so tied scores cannot raise.
            best_index = max(
                range(len(legal_moves)),
                key=lambda i: policy_logits[0, i].item(),
            )
            return legal_moves[best_index]

        def heuristic_score(move):
            """Score a move: captures first, then central squares, plus jitter."""
            score = 0
            if board.is_capture(move):
                score += 10
            file = chess.square_file(move.to_square)
            rank = chess.square_rank(move.to_square)
            if 2 <= file <= 5 and 2 <= rank <= 5:  # files c-f, ranks 3-6
                score += 5
            return score + random.random()

        return max(legal_moves, key=heuristic_score)

    except Exception as e:
        st.error(f"Error in o1 prediction: {e}")
        # Fallback to random legal move
        return random.choice(legal_moves)

# Per-session game state that survives Streamlit reruns: the live board and
# the list of UCI strings played so far (replayed to build the PGN).
for _state_key, _make_default in (("board", chess.Board), ("history", list)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

# Load o1 (cached by load_o1); on failure keep the app usable by falling back
# to random moves for Black instead of crashing.
try:
    o1 = load_o1()
except Exception as e:
    st.error(f"Failed to load o1: {e}")
    o1, o1_loaded = None, False
else:
    o1_loaded = True

# Short aliases for this run; they refer to the objects held in session
# state, so board.push()/history.append() below update that state directly.
board, history = st.session_state.board, st.session_state.history

st.title("♟️ Play Chess vs o1")

if not o1_loaded:
    st.warning("o1 failed to load. The app will use random moves for the AI.")

# Replace the stored game with a fresh one, then rerun the script.
if st.button("Reset Game"):
    st.session_state["board"] = chess.Board()
    st.session_state["history"] = []
    st.rerun()

board_placeholder = st.empty()

def render_board():
    """Draw the current position into the placeholder, highlighting the last move."""
    try:
        highlight = board.move_stack[-1] if board.move_stack else None
        svg_markup = chess.svg.board(board=board, lastmove=highlight, size=400)
        centered = f'<div style="display: flex; justify-content: center;">{svg_markup}</div>'
        board_placeholder.markdown(centered, unsafe_allow_html=True)
    except Exception as err:
        st.error(f"Error rendering board: {err}")

render_board()

# Game status row: side to move on the left, a check notice on the right.
turn_col, check_col = st.columns(2)
with turn_col:
    side_to_move = "White" if board.turn == chess.WHITE else "Black"
    st.write(f"**Turn:** {side_to_move}")
with check_col:
    if board.is_check():
        st.write("**Check!**")

# Player move input (White)
if not board.is_game_over() and board.turn == chess.WHITE:
    st.write("### Your Turn (White)")

    # Show legal moves in both notations for reference. Every move here comes
    # from board.legal_moves, so board.san() can describe each one.
    legal_moves = list(board.legal_moves)
    legal_moves_uci = [move.uci() for move in legal_moves]
    legal_moves_san = [board.san(move) for move in legal_moves]

    with st.expander("Show legal moves"):
        col1, col2 = st.columns(2)
        with col1:
            st.write("**Algebraic notation:**")
            st.write(", ".join(sorted(legal_moves_san)))
        with col2:
            st.write("**UCI notation:**")
            st.write(", ".join(sorted(legal_moves_uci)))

    user_move = st.text_input("Enter your move (e.g., E4, Nf3, e2e4, O-O):", key="move_input",
                             help="You can use algebraic notation (E4, Nf3) or UCI notation (e2e4). Case doesn't matter!")

    col1, col2 = st.columns(2)
    with col1:
        if st.button("Submit Move"):
            if not user_move:
                st.warning("Please enter a move.")
            else:
                parsed_move = parse_move_input(user_move, board)
                if parsed_move is None:
                    st.warning(f"Invalid move: '{user_move}'. Please check the legal moves above.")
                else:
                    try:
                        # SAN must be computed BEFORE push(): board.san() expects
                        # a move that is legal in the current position, which is
                        # no longer true once the move has been played.
                        move_san = board.san(parsed_move)
                        board.push(parsed_move)
                    except Exception as e:
                        st.error(f"Error making move: {e}")
                    else:
                        history.append(parsed_move.uci())
                        st.success(f"You played: {move_san} ({parsed_move.uci()})")
                        # Called outside the try so the broad handler above can
                        # never intercept Streamlit's rerun signal.
                        st.rerun()

    with col2:
        if st.button("Random Move"):
            if legal_moves:
                random_move = random.choice(legal_moves)
                board.push(random_move)
                history.append(random_move.uci())
                st.rerun()

# o1 move (Black)
if not board.is_game_over() and board.turn == chess.BLACK:
    st.write("### o1's Turn (Black)")
    o1_moved = False
    with st.spinner("o1 is thinking..."):
        try:
            if o1_loaded and o1:
                best_move = get_o1_move(o1, board)
            else:
                # Fallback to random move if o1 not loaded
                legal_moves = list(board.legal_moves)
                best_move = random.choice(legal_moves) if legal_moves else None

            if best_move is not None and best_move in board.legal_moves:
                # SAN is computed before push() while the move is still legal.
                move_san = board.san(best_move)
                board.push(best_move)
                history.append(best_move.uci())
                st.success(f"o1 played: {move_san} ({best_move.uci()})")
                o1_moved = True
            else:
                st.error("o1 couldn't find a valid move")

        except Exception as e:
            st.error(f"Error during o1 move: {e}")

    # Rerun outside the try: depending on the Streamlit version, the signal
    # st.rerun() raises could otherwise be caught by `except Exception`.
    if o1_moved:
        st.rerun()

# Game history
if history:
    st.write("### Game History")

    # Rebuild a PGN by replaying the stored UCI moves on a fresh board.
    try:
        game = chess.pgn.Game()
        game.headers["Event"] = "Human vs o1"
        game.headers["White"] = "Human"
        game.headers["Black"] = "o1"

        node = game
        temp_board = chess.Board()

        for uci in history:
            move = chess.Move.from_uci(uci)
            if move not in temp_board.legal_moves:
                st.warning(f"Invalid move in history: {uci}")
                break
            node = node.add_main_variation(move)
            temp_board.push(move)

        # Record the actual result ("1-0", "0-1", "1/2-1/2", or "*" while the
        # game is in progress) so finished games don't export as unfinished.
        game.headers["Result"] = temp_board.result()

        st.code(str(game), language="pgn")

    except Exception as e:
        st.error(f"Error generating PGN: {e}")
        # Fallback: show a numbered list of raw UCI move pairs.
        move_pairs = []
        for i in range(0, len(history), 2):
            white_move = history[i]
            black_move = history[i + 1] if i + 1 < len(history) else ""
            move_pairs.append(f"{i // 2 + 1}. {white_move} {black_move}")
        st.code("\n".join(move_pairs))

# Game over: announce the winner/draw, show result details, offer a new game.
if board.is_game_over():
    st.write("### Game Over!")
    result = board.result()
    outcome = board.outcome()

    # Map the PGN result string to a (renderer, message) pair; anything else
    # (i.e. "1/2-1/2") is a draw.
    result_banners = {
        "1-0": (st.success, "White wins!"),
        "0-1": (st.error, "Black wins!"),
    }
    show_banner, banner_text = result_banners.get(result, (st.info, "Draw!"))
    show_banner(banner_text)

    st.write(f"**Result:** {result}")
    st.write(f"**Termination:** {outcome.termination.name}")

    if st.button("Start New Game"):
        st.session_state.board = chess.Board()
        st.session_state.history = []
        st.rerun()