import contextlib
import io
import json
import os
import tempfile
import traceback
from pathlib import Path

import chess
import chess.svg
import cv2
import gradio as gr
import numpy as np

from preprocess import preprocess_image
from train import create_model

# Load the class order from the file generated by the training script, so that
# model output indices map to the same class names that were used in training.
try:
    with open('./class_indices.json', 'r', encoding='utf-8') as f:
        class_indices = json.load(f)
    # The file maps name -> index; invert it to get index -> name.
    PIECES = [None] * len(class_indices)
    for name, idx in class_indices.items():
        PIECES[idx] = name
    print(f"Ordre des classes chargé: {PIECES}")
except FileNotFoundError:
    # The file is missing: fall back to an alphabetical class order
    # (presumably the order the training data loader used — TODO confirm
    # against train.py).
    PIECES = ['Bishop_Black', 'Bishop_White', 'Empty', 'King_Black', 'King_White',
              'Knight_Black', 'Knight_White', 'Pawn_Black', 'Pawn_White',
              'Queen_Black', 'Queen_White', 'Rook_Black', 'Rook_White']
    print("Fichier class_indices.json non trouvé, utilisation ordre par défaut")

# Class name -> FEN piece letter ('.' marks an empty square).
LABELS = {
    'Empty': '.',
    'Rook_White': 'R',
    'Rook_Black': 'r',
    'Knight_White': 'N',
    'Knight_Black': 'n',
    'Bishop_White': 'B',
    'Bishop_Black': 'b',
    'Queen_White': 'Q',
    'Queen_Black': 'q',
    'King_White': 'K',
    'King_Black': 'k',
    'Pawn_White': 'P',
    'Pawn_Black': 'p',
}

# Load the trained model once at import time.
print("Loading model...")
model = create_model()
model.load_weights('./model_weights.weights.h5')
print("Model loaded!")


def classify_image(img):
    """Classify a single square crop and return its class name from PIECES.

    The crop is normalised to [0, 1] the same way as during training
    (rescale=1/255) and fed to the model as a batch of one.

    Args:
        img: image array holding 300*150*3 values; presumably shaped
            (300, 150, 3), i.e. two squares of a 1200x1200 board —
            TODO confirm against preprocess_image / create_model.

    Returns:
        The class name (an entry of PIECES) with the highest probability.
    """
    # Integer images (e.g. uint8 crops) are always in [0, 255] and must be
    # rescaled even when very dark (max <= 1). Float images are rescaled only
    # when they still look like they are in the [0, 255] range.
    if np.issubdtype(img.dtype, np.integer) or img.max() > 1.0:
        img = img.astype(np.float32) / 255.0
    else:
        img = img.astype(np.float32)
    y_prob = model.predict(img.reshape(1, 300, 150, 3), verbose=0)
    return PIECES[int(y_prob.argmax())]


def _promote_missing_kings(arr):
    """Relabel a queen as the king for each colour whose king was not detected.

    Heuristic kept from the original "King-Queen detection" adjustment,
    presumably because the classifier confuses kings and queens: if no king
    of a colour is present, the last queen of that colour in row-major order
    becomes the king. Each colour is handled independently. Mutates ``arr``.
    """
    for king, queen in (('K', 'Q'), ('k', 'q')):
        if any(king in row for row in arr):
            continue
        queen_positions = [(i, j)
                           for i, row in enumerate(arr)
                           for j, cell in enumerate(row)
                           if cell == queen]
        if queen_positions:
            i, j = queen_positions[-1]
            arr[i][j] = king


def analyze_board(img):
    """Classify every square of a board image into FEN piece letters.

    Args:
        img: board image from preprocess_image, assumed to be (H, W, 3) with
            the board filling the frame — TODO confirm.

    Returns:
        An 8x8 list of rows of single-character labels (see LABELS). Row 0 is
        the top of the image.
    """
    square_h = img.shape[0] // 8
    square_w = img.shape[1] // 8
    arr = []
    # Iterate exactly 8x8 squares: open-ended ranges over the image size
    # produce a 9th (narrow) row/column when a dimension is not divisible by 8.
    for row_idx in range(8):
        # Exclusive bottom edge of the crop, one pixel above the bottom of
        # this square (same geometry as the original cropping).
        y = (row_idx + 1) * square_h - 1
        top = y - 2 * square_h
        row = []
        for col_idx in range(8):
            x = col_idx * square_w
            # The crop is two squares tall, so pieces that extend into the
            # square above are included.
            sub_img = img[max(0, top):y, x:x + square_w]
            if top < 0:
                # Top rows: pad with black above the image so the crop keeps
                # its full two-square height.
                padding = np.zeros((2 * square_h - y, square_w, 3))
                sub_img = np.concatenate((padding, sub_img))
            sub_img = sub_img.astype(np.uint8)
            row.append(LABELS[classify_image(sub_img)])
        arr.append(row)

    _promote_missing_kings(arr)
    return arr


def board_to_fen(board):
    """Convert a grid of piece letters ('.' = empty) into a FEN string.

    Rows are written in order, so row 0 becomes rank 8 (this assumes White is
    at the bottom of the image). The remaining fields are fixed:
    white to move, castling "KQkq", no en passant, clocks "0 1".
    """
    ranks = []
    for row in board:
        parts = []
        empty = 0
        for cell in row:
            if cell == '.':
                empty += 1
                continue
            if empty:
                parts.append(str(empty))
                empty = 0
            parts.append(cell)
        if empty:
            parts.append(str(empty))
        ranks.append(''.join(parts))
    return '/'.join(ranks) + ' w KQkq - 0 1'


def analyze_chess_image(image_input):
    """Gradio callback: turn an uploaded board image into FEN plus an SVG board.

    Args:
        image_input: a PIL image (the interface uses type="pil"), an RGB
            numpy array, or None.

    Returns:
        (status text, SVG markup or None). On failure the status text is
        "Error: ..." and the traceback is printed.
    """
    if image_input is None:
        return "❌ No image provided", None

    temp_path = None
    try:
        # preprocess_image() takes a file path, so the upload goes through a
        # temporary file. The handle is closed before writing so the path can
        # be reopened on every platform (Windows forbids a second open).
        with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
            temp_path = tmp.name
        if isinstance(image_input, np.ndarray):
            cv2.imwrite(temp_path, cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR))
        else:
            # JPEG cannot store alpha/palette modes (e.g. PNG uploads).
            image_input.convert('RGB').save(temp_path)

        # Same preprocessing (LAPS-based board extraction) as main.py.
        img = preprocess_image(temp_path, save=False)
        arr = analyze_board(img)
        fen = board_to_fen(arr)

        board = chess.Board(fen)
        # board_to_fen always claims "KQkq"; drop rights that are impossible
        # given where the kings and rooks actually stand.
        board.castling_rights = board.clean_castling_rights()
        fen = board.fen()
        board_svg = chess.svg.board(board=board, size=400)
        return fen, board_svg
    except Exception as e:
        print(traceback.format_exc())
        return f"Error: {str(e)}", None
    finally:
        # Always remove the temporary file, including on failure.
        if temp_path is not None:
            with contextlib.suppress(OSError):
                os.unlink(temp_path)


# Build the Gradio interface.
with gr.Blocks(title="Chess Board picture -> FEN notation", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# ♟️ YOCO: You Only Look Once\n\n"
        "Upload a chess board image to automatically detect all pieces "
        "and get the FEN notation."
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload chess board image", type="pil")
            submit_btn = gr.Button("Analyze Board", size="lg", variant="primary")

        with gr.Column():
            status_output = gr.Textbox(label="Result", interactive=False, lines=2)
            board_output = gr.HTML(label="Board Visualization")

    submit_btn.click(
        fn=analyze_chess_image,
        inputs=image_input,
        outputs=[status_output, board_output],
    )

if __name__ == "__main__":
    demo.launch()