hoololi commited on
Commit
1093dfb
·
verified ·
1 Parent(s): 755d9e4

Upload 2 files

Browse files
Files changed (2) hide show
  1. game_engine.py +43 -9
  2. image_processing_gpu.py +136 -33
game_engine.py CHANGED
@@ -1,9 +1,9 @@
1
  # ==========================================
2
- # game_engine.py - Avec métriques OCR et dataset optimisé
3
  # ==========================================
4
 
5
  """
6
- Moteur de jeu avec tracking complet des performances OCR
7
  """
8
 
9
  import random
@@ -22,10 +22,13 @@ from image_processing_gpu import (
22
  create_thumbnail_fast,
23
  create_white_canvas,
24
  cleanup_memory,
25
- get_ocr_model_info
 
 
 
26
  )
27
 
28
- print("✅ Game Engine: Mode GPU avec métriques OCR")
29
 
30
  # Imports dataset
31
  try:
@@ -47,6 +50,37 @@ DIFFICULTY_RANGES = {
47
  "÷": {"Facile": (1, 10), "Difficile": (2, 12)}
48
  }
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  def create_result_row_with_metrics(i: int, image: dict | np.ndarray | Image.Image, expected: int, operation_data: tuple[int, int, str, int]) -> dict:
51
  """Traite une image avec OCR et mesure les métriques"""
52
 
@@ -104,7 +138,7 @@ def create_result_row_with_metrics(i: int, image: dict | np.ndarray | Image.Imag
104
 
105
 
106
  class MathGame:
107
- """Moteur de jeu avec métriques OCR complètes"""
108
 
109
  def __init__(self):
110
  self.is_running = False
@@ -345,14 +379,14 @@ class MathGame:
345
 
346
  print(f"🔄 Traitement OCR avec métriques de {total_questions} images...")
347
 
348
- # Récupérer infos modèle OCR une seule fois
349
  try:
350
  ocr_model_info = get_ocr_model_info()
351
- model_name = ocr_model_info.get("model_name", "microsoft/trocr-base-handwritten")
352
  hardware = f"{ocr_model_info.get('device', 'Unknown')}-{ocr_model_info.get('gpu_name', 'Unknown')}"
353
  except Exception as e:
354
  print(f"❌ Erreur get_ocr_model_info: {e}")
355
- model_name = "microsoft/trocr-base-handwritten"
356
  hardware = "ZeroGPU-Unknown"
357
 
358
  # Boucle OCR avec métriques
@@ -407,7 +441,7 @@ class MathGame:
407
  "session_total_questions": total_questions,
408
 
409
  # Métadonnées techniques
410
- "app_version": "3.1_with_ocr_metrics",
411
  "hardware": hardware
412
  }
413
 
 
1
  # ==========================================
2
+ # game_engine.py - Avec métriques OCR et dataset optimisé + modèles commutables
3
  # ==========================================
4
 
5
  """
6
+ Moteur de jeu avec tracking complet des performances OCR et support modèles commutables
7
  """
8
 
9
  import random
 
22
  create_thumbnail_fast,
23
  create_white_canvas,
24
  cleanup_memory,
25
+ get_ocr_model_info,
26
+ get_available_models,
27
+ set_ocr_model,
28
+ get_current_model_info
29
  )
30
 
31
+ print("✅ Game Engine: Mode GPU avec métriques OCR et modèles commutables")
32
 
33
  # Imports dataset
34
  try:
 
50
  "÷": {"Facile": (1, 10), "Difficile": (2, 12)}
51
  }
52
 
53
+ def get_ocr_models_info() -> dict:
54
+ """Retourne les informations sur les modèles OCR disponibles"""
55
+ try:
56
+ available_models = get_available_models()
57
+ current_model = get_current_model_info()
58
+
59
+ return {
60
+ "available_models": available_models,
61
+ "current_model": current_model,
62
+ "model_names": list(available_models.keys())
63
+ }
64
+ except Exception as e:
65
+ print(f"❌ Erreur get_ocr_models_info: {e}")
66
+ return {
67
+ "available_models": {},
68
+ "current_model": {"model_name": "hoololi/trocr-base-handwritten-calctrainer"},
69
+ "model_names": []
70
+ }
71
+
72
+ def switch_ocr_model(model_name: str) -> str:
73
+ """Change le modèle OCR et retourne un message de statut"""
74
+ try:
75
+ success = set_ocr_model(model_name)
76
+ if success:
77
+ model_info = get_current_model_info()
78
+ return f"✅ Modèle changé vers: {model_info['display_name']}\n📍 {model_info['description']}"
79
+ else:
80
+ return f"❌ Échec du changement vers: {model_name}"
81
+ except Exception as e:
82
+ return f"❌ Erreur lors du changement: {str(e)}"
83
+
84
  def create_result_row_with_metrics(i: int, image: dict | np.ndarray | Image.Image, expected: int, operation_data: tuple[int, int, str, int]) -> dict:
85
  """Traite une image avec OCR et mesure les métriques"""
86
 
 
138
 
139
 
140
  class MathGame:
141
+ """Moteur de jeu avec métriques OCR complètes et modèles commutables"""
142
 
143
  def __init__(self):
144
  self.is_running = False
 
379
 
380
  print(f"🔄 Traitement OCR avec métriques de {total_questions} images...")
381
 
382
+ # Récupérer infos modèle OCR une seule fois - MODIFIÉ pour utiliser le nouveau système
383
  try:
384
  ocr_model_info = get_ocr_model_info()
385
+ model_name = ocr_model_info.get("model_name", "hoololi/trocr-base-handwritten-calctrainer")
386
  hardware = f"{ocr_model_info.get('device', 'Unknown')}-{ocr_model_info.get('gpu_name', 'Unknown')}"
387
  except Exception as e:
388
  print(f"❌ Erreur get_ocr_model_info: {e}")
389
+ model_name = "hoololi/trocr-base-handwritten-calctrainer"
390
  hardware = "ZeroGPU-Unknown"
391
 
392
  # Boucle OCR avec métriques
 
441
  "session_total_questions": total_questions,
442
 
443
  # Métadonnées techniques
444
+ "app_version": "3.2_with_switchable_models",
445
  "hardware": hardware
446
  }
447
 
image_processing_gpu.py CHANGED
@@ -1,5 +1,5 @@
1
  # ==========================================
2
- # image_processing_gpu.py - Version ZeroGPU simplifiée
3
  # ==========================================
4
 
5
  """
@@ -21,23 +21,137 @@ from utils import (
21
  validate_ocr_result
22
  )
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  # Variables globales pour OCR
25
  processor = None
26
  model = None
27
- #OCR_MODEL_NAME = "TrOCR-base-handwritten"
28
- OCR_MODEL_NAME = "hoololi/trocr-base-handwritten-calctrainer"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- def init_ocr_model() -> bool:
31
- """Initialise TrOCR pour ZeroGPU"""
32
- global processor, model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  try:
35
- print("🔄 Chargement TrOCR (ZeroGPU)...")
 
36
 
37
- #processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
38
- #model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten')
39
- processor = TrOCRProcessor.from_pretrained('hoololi/trocr-base-handwritten-calctrainer')
40
- model = VisionEncoderDecoderModel.from_pretrained('hoololi/trocr-base-handwritten-calctrainer')
41
 
42
  # Optimisations
43
  model.eval()
@@ -45,39 +159,26 @@ def init_ocr_model() -> bool:
45
  if torch.cuda.is_available():
46
  model = model.cuda()
47
  device_info = f"GPU ({torch.cuda.get_device_name()})"
48
- print(f"✅ TrOCR prêt sur {device_info} !")
49
  else:
50
  device_info = "CPU (ZeroGPU pas encore alloué)"
51
- print(f"⚠️ TrOCR sur CPU - {device_info}")
52
 
53
  return True
54
 
55
  except Exception as e:
56
- print(f"❌ Erreur lors du chargement TrOCR: {e}")
57
  return False
58
 
 
59
  def get_ocr_model_info() -> dict:
60
- """Retourne les informations du modèle OCR utilisé"""
61
- if torch.cuda.is_available():
62
- device = "ZeroGPU"
63
- gpu_name = torch.cuda.get_device_name()
64
- else:
65
- device = "CPU"
66
- gpu_name = "N/A"
67
-
68
- return {
69
- "model_name": OCR_MODEL_NAME,
70
- "device": device,
71
- "gpu_name": gpu_name,
72
- "framework": "HuggingFace-Transformers-ZeroGPU",
73
- "optimized_for": "accuracy",
74
- "version": "microsoft/trocr-base-handwritten"
75
- }
76
 
77
  @spaces.GPU
78
  def recognize_number_fast_with_image(image_dict, debug: bool = False) -> tuple[str, any, dict | None]:
79
  """
80
- OCR avec TrOCR ZeroGPU - Version simplifiée
81
  """
82
  if image_dict is None:
83
  if debug:
@@ -87,7 +188,8 @@ def recognize_number_fast_with_image(image_dict, debug: bool = False) -> tuple[s
87
  try:
88
  start_time = time.time()
89
  if debug:
90
- print(" 🔄 Début OCR TrOCR ZeroGPU...")
 
91
 
92
  # Optimiser image
93
  optimized_image = optimize_image_for_ocr(image_dict, max_size=384)
@@ -133,7 +235,8 @@ def recognize_number_fast_with_image(image_dict, debug: bool = False) -> tuple[s
133
  if debug:
134
  total_time = time.time() - start_time
135
  device = "ZeroGPU" if torch.cuda.is_available() else "CPU"
136
- print(f" ✅ TrOCR ({device}) terminé en {total_time:.1f}s → '{final_result}'")
 
137
  if dataset_image_data:
138
  print(f" 🖼️ Image dataset: {type(dataset_image_data.get('handwriting_image', 'None'))}")
139
 
 
1
  # ==========================================
2
+ # image_processing_gpu.py - Version ZeroGPU avec modèles OCR commutables
3
  # ==========================================
4
 
5
  """
 
21
  validate_ocr_result
22
  )
23
 
24
+ # ==========================================
25
+ # Configuration des modèles OCR disponibles
26
+ # ==========================================
27
+
28
+ AVAILABLE_OCR_MODELS = {
29
+ "microsoft/trocr-base-handwritten": {
30
+ "description": "Modèle de base Microsoft pour écriture manuscrite",
31
+ "display_name": "TrOCR Base Handwritten (Microsoft)",
32
+ "optimized_for": "general_handwriting"
33
+ },
34
+ "hoololi/trocr-base-handwritten-calctrainer": {
35
+ "description": "Modèle fine tuné pour les nombres entiers",
36
+ "display_name": "TrOCR CalcTrainer (Hoololi)",
37
+ "optimized_for": "mathematical_numbers"
38
+ }
39
+ }
40
+
41
+ # Modèle par défaut
42
+ DEFAULT_OCR_MODEL = "hoololi/trocr-base-handwritten-calctrainer"
43
+ current_ocr_model_name = DEFAULT_OCR_MODEL
44
+
45
  # Variables globales pour OCR
46
  processor = None
47
  model = None
48
+ current_loaded_model = None
49
+
50
+ def get_available_models() -> dict:
51
+ """Retourne la liste des modèles OCR disponibles"""
52
+ return AVAILABLE_OCR_MODELS
53
+
54
+ def get_current_model_info() -> dict:
55
+ """Retourne les informations du modèle OCR actuellement chargé"""
56
+ global current_ocr_model_name, current_loaded_model
57
+
58
+ model_config = AVAILABLE_OCR_MODELS.get(current_ocr_model_name, AVAILABLE_OCR_MODELS[DEFAULT_OCR_MODEL])
59
+
60
+ if torch.cuda.is_available():
61
+ device = "ZeroGPU"
62
+ gpu_name = torch.cuda.get_device_name()
63
+ else:
64
+ device = "CPU"
65
+ gpu_name = "N/A"
66
+
67
+ return {
68
+ "model_name": current_ocr_model_name,
69
+ "display_name": model_config["display_name"],
70
+ "description": model_config["description"],
71
+ "current_loaded": current_loaded_model,
72
+ "device": device,
73
+ "gpu_name": gpu_name,
74
+ "framework": "HuggingFace-Transformers-ZeroGPU",
75
+ "optimized_for": model_config["optimized_for"],
76
+ "is_loaded": processor is not None and model is not None,
77
+ # Compatibilité avec l'ancien code
78
+ "version": current_ocr_model_name
79
+ }
80
+
81
+ def set_ocr_model(model_name: str) -> bool:
82
+ """
83
+ Change le modèle OCR actif
84
+
85
+ Args:
86
+ model_name: Nom exact du modèle (ex: "microsoft/trocr-base-handwritten")
87
+
88
+ Returns:
89
+ bool: True si le changement a réussi
90
+ """
91
+ global current_ocr_model_name
92
+
93
+ if model_name not in AVAILABLE_OCR_MODELS:
94
+ print(f"❌ Modèle '{model_name}' non disponible. Modèles disponibles: {list(AVAILABLE_OCR_MODELS.keys())}")
95
+ return False
96
+
97
+ if model_name == current_ocr_model_name and processor is not None and model is not None:
98
+ print(f"✅ Modèle '{model_name}' déjà chargé")
99
+ return True
100
+
101
+ model_config = AVAILABLE_OCR_MODELS[model_name]
102
+ print(f"🔄 Changement vers le modèle: {model_config['display_name']}")
103
+ current_ocr_model_name = model_name
104
+
105
+ # Nettoyer le modèle précédent
106
+ cleanup_current_model()
107
+
108
+ # Charger le nouveau modèle
109
+ return init_ocr_model()
110
 
111
+ def cleanup_current_model():
112
+ """Nettoie le modèle actuellement chargé pour libérer la mémoire"""
113
+ global processor, model, current_loaded_model
114
+
115
+ if model is not None:
116
+ del model
117
+ model = None
118
+
119
+ if processor is not None:
120
+ del processor
121
+ processor = None
122
+
123
+ current_loaded_model = None
124
+
125
+ # Nettoyage mémoire GPU si disponible
126
+ if torch.cuda.is_available():
127
+ torch.cuda.empty_cache()
128
+
129
+ print("🧹 Modèle précédent nettoyé")
130
+
131
+ def init_ocr_model(model_name: str = None) -> bool:
132
+ """
133
+ Initialise TrOCR pour ZeroGPU avec le modèle spécifié
134
+
135
+ Args:
136
+ model_name: Nom exact du modèle à charger (optionnel, utilise current_ocr_model_name par défaut)
137
+ """
138
+ global processor, model, current_ocr_model_name, current_loaded_model
139
+
140
+ if model_name is not None:
141
+ if model_name not in AVAILABLE_OCR_MODELS:
142
+ print(f"❌ Modèle '{model_name}' non disponible")
143
+ return False
144
+ current_ocr_model_name = model_name
145
+
146
+ model_config = AVAILABLE_OCR_MODELS[current_ocr_model_name]
147
 
148
  try:
149
+ print(f"🔄 Chargement {model_config['display_name']} (ZeroGPU)...")
150
+ print(f" 📍 Modèle: {current_ocr_model_name}")
151
 
152
+ processor = TrOCRProcessor.from_pretrained(current_ocr_model_name)
153
+ model = VisionEncoderDecoderModel.from_pretrained(current_ocr_model_name)
154
+ current_loaded_model = current_ocr_model_name
 
155
 
156
  # Optimisations
157
  model.eval()
 
159
  if torch.cuda.is_available():
160
  model = model.cuda()
161
  device_info = f"GPU ({torch.cuda.get_device_name()})"
162
+ print(f"✅ {model_config['display_name']} prêt sur {device_info} !")
163
  else:
164
  device_info = "CPU (ZeroGPU pas encore alloué)"
165
+ print(f"⚠️ {model_config['display_name']} sur CPU - {device_info}")
166
 
167
  return True
168
 
169
  except Exception as e:
170
+ print(f"❌ Erreur lors du chargement {model_config['display_name']}: {e}")
171
  return False
172
 
173
+ # Alias pour compatibilité avec l'ancien code
174
  def get_ocr_model_info() -> dict:
175
+ """Alias pour get_current_model_info() - compatibilité"""
176
+ return get_current_model_info()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
  @spaces.GPU
179
  def recognize_number_fast_with_image(image_dict, debug: bool = False) -> tuple[str, any, dict | None]:
180
  """
181
+ OCR avec TrOCR ZeroGPU - Version simplifiée avec modèle commutable
182
  """
183
  if image_dict is None:
184
  if debug:
 
188
  try:
189
  start_time = time.time()
190
  if debug:
191
+ model_info = get_current_model_info()
192
+ print(f" 🔄 Début OCR {model_info['display_name']} ZeroGPU...")
193
 
194
  # Optimiser image
195
  optimized_image = optimize_image_for_ocr(image_dict, max_size=384)
 
235
  if debug:
236
  total_time = time.time() - start_time
237
  device = "ZeroGPU" if torch.cuda.is_available() else "CPU"
238
+ model_name = get_current_model_info()['display_name']
239
+ print(f" ✅ {model_name} ({device}) terminé en {total_time:.1f}s → '{final_result}'")
240
  if dataset_image_data:
241
  print(f" 🖼️ Image dataset: {type(dataset_image_data.get('handwriting_image', 'None'))}")
242