nathbns commited on
Commit
41c1ae6
·
verified ·
1 Parent(s): 3f1dd27

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -5
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  import chess
3
  import chess.svg
4
  import io
 
5
  import numpy as np
6
  import cv2
7
  import os
@@ -10,11 +11,22 @@ from pathlib import Path
10
 
11
  # Import EXACT SAME functions from main.py
12
  from preprocess import preprocess_image
13
- from train_tensorflow import create_model
14
-
15
- PIECES = ['Empty', 'Rook_White', 'Rook_Black', 'Knight_White', 'Knight_Black', 'Bishop_White',
16
- 'Bishop_Black', 'Queen_White', 'Queen_Black', 'King_White', 'King_Black', 'Pawn_White', 'Pawn_Black']
17
- PIECES.sort()
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  LABELS = {
20
  'Empty': '.',
@@ -40,6 +52,14 @@ print("Model loaded!")
40
 
41
 
42
  def classify_image(img):
 
 
 
 
 
 
 
 
43
  y_prob = model.predict(img.reshape(1, 300, 150, 3), verbose=0)
44
  y_pred = y_prob.argmax()
45
  return PIECES[y_pred]
 
2
  import chess
3
  import chess.svg
4
  import io
5
+ import json
6
  import numpy as np
7
  import cv2
8
  import os
 
11
 
12
  # Import EXACT SAME functions from main.py
13
  from preprocess import preprocess_image
14
+ from train import create_model
15
+
16
+ # Charger l'ordre des classes depuis le fichier généré par train.py
17
+ try:
18
+ with open('./class_indices.json', 'r') as f:
19
+ class_indices = json.load(f)
20
+ # Inverser pour avoir index -> nom
21
+ PIECES = [None] * len(class_indices)
22
+ for name, idx in class_indices.items():
23
+ PIECES[idx] = name
24
+ print(f"Ordre des classes chargé: {PIECES}")
25
+ except FileNotFoundError:
26
+ # Fallback sur ordre alphabétique si le fichier n'existe pas
27
+ PIECES = ['Bishop_Black', 'Bishop_White', 'Empty', 'King_Black', 'King_White', 'Knight_Black',
28
+ 'Knight_White', 'Pawn_Black', 'Pawn_White', 'Queen_Black', 'Queen_White', 'Rook_Black', 'Rook_White']
29
+ print(f"Fichier class_indices.json non trouvé, utilisation ordre par défaut")
30
 
31
  LABELS = {
32
  'Empty': '.',
 
52
 
53
 
54
  def classify_image(img):
55
+ '''Given an image of a single piece, classifies it into one of the classes
56
+ defined in PIECES.'''
57
+ # IMPORTANT: Normaliser l'image comme dans l'entraînement (rescale=1/255)
58
+ if img.max() > 1.0:
59
+ img = img.astype(np.float32) / 255.0
60
+ else:
61
+ img = img.astype(np.float32)
62
+
63
  y_prob = model.predict(img.reshape(1, 300, 150, 3), verbose=0)
64
  y_pred = y_prob.argmax()
65
  return PIECES[y_pred]