mohakapoor commited on
Commit
858aaf1
Β·
1 Parent(s): 04fdb24

update inference.py

Browse files
Files changed (4) hide show
  1. Metrics/inference_results.png +2 -2
  2. inference.py +18 -18
  3. requirements.txt +17 -0
  4. train.py +8 -8
Metrics/inference_results.png CHANGED

Git LFS Details

  • SHA256: 3f354ee931ae653ed9821adbfb33c715ad310aad312064770238e879579ef078
  • Pointer size: 131 Bytes
  • Size of remote file: 209 kB

Git LFS Details

  • SHA256: 93244e9da7d2a23effdba1e9580ff7bedff2b10ea1e814e4dc97e323939b5c4a
  • Pointer size: 131 Bytes
  • Size of remote file: 207 kB
inference.py CHANGED
@@ -19,7 +19,7 @@ def load_model(checkpoint_path="checkpoints/best_model.pth"):
19
 
20
  # Load checkpoint to the detected device
21
  checkpoint = torch.load(checkpoint_path, map_location=device)
22
- print(f"βœ… Loaded model from epoch {checkpoint['epoch']}")
23
  print(f" Best validation loss: {checkpoint['best_val_loss']:.4f}")
24
  print(f" Loading to device: {device}")
25
 
@@ -74,25 +74,25 @@ def generate_test_captcha(text, filename, width=160, height=60):
74
  image = ImageCaptcha(width=width, height=height)
75
  filepath = os.path.join(cfg.RESULT_DIR, filename)
76
  image.write(text, filepath)
77
- print(f"πŸ“Έ Generated test CAPTCHA: {filename}")
78
  return filepath
79
 
80
  def main():
81
  # Setup
82
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
83
- print(f"πŸš€ Using device: {device}")
84
 
85
  os.makedirs(cfg.RESULT_DIR, exist_ok=True)
86
 
87
  try:
88
  # Load trained model
89
- print("πŸ“₯ Loading trained model...")
90
  model = load_model()
91
  model = model.to(device)
92
- print("βœ… Model loaded successfully!")
93
 
94
  # Generate test CAPTCHAs
95
- print("\n🎯 Generating test CAPTCHAs...")
96
  test_cases = []
97
 
98
  for i in range(4):
@@ -105,7 +105,7 @@ def main():
105
  test_cases.append((text, image_path, "")) # Add empty prediction slot
106
 
107
  # Run inference
108
- print("\nπŸ” Running inference...")
109
  print("-" * 60)
110
  print(f"{'Target':<15} {'Prediction':<15} {'Correct':<10} {'Image':<20}")
111
  print("-" * 60)
@@ -128,16 +128,16 @@ def main():
128
  correct_count += 1
129
 
130
  # Display result
131
- status = "βœ…" if is_correct else "❌"
132
  print(f"{target_text:<15} {prediction:<15} {status:<10} {os.path.basename(image_path):<20}")
133
 
134
  except Exception as e:
135
- print(f"❌ Error processing {image_path}: {e}")
136
 
137
  # Summary
138
  print("-" * 60)
139
  accuracy = (correct_count / len(test_cases)) * 100
140
- print(f"πŸ“Š Overall Accuracy: {correct_count}/{len(test_cases)} ({accuracy:.1f}%)")
141
 
142
  # Calculate individual character accuracy
143
  total_chars = 0
@@ -154,14 +154,14 @@ def main():
154
  print(f"πŸ”€ Character Accuracy: {correct_chars}/{total_chars} ({char_accuracy:.1f}%)")
155
 
156
  if accuracy >= 80:
157
- print("πŸŽ‰ Excellent performance!")
158
  elif accuracy >= 60:
159
- print("πŸ‘ Good performance!")
160
  else:
161
- print("πŸ€” Room for improvement...")
162
 
163
  # Create and save results plot
164
- print("\nπŸ“Š Generating results visualization...")
165
  try:
166
  metrics = TrainingMetrics()
167
  image_paths = [case[1] for case in test_cases]
@@ -173,14 +173,14 @@ def main():
173
 
174
  # Plot results
175
  metrics.plot_results(image_paths, predictions, targets)
176
- print("βœ… Results plot generated successfully!")
177
 
178
  except Exception as e:
179
- print(f"⚠️ Warning: Could not generate plot: {e}")
180
 
181
  except Exception as e:
182
- print(f"❌ Error: {e}")
183
- print("πŸ’‘ Make sure you have a trained model in checkpoints/best_model.pth")
184
 
185
  if __name__ == "__main__":
186
  main()
 
19
 
20
  # Load checkpoint to the detected device
21
  checkpoint = torch.load(checkpoint_path, map_location=device)
22
+ print(f"Loaded model from epoch {checkpoint['epoch']}")
23
  print(f" Best validation loss: {checkpoint['best_val_loss']:.4f}")
24
  print(f" Loading to device: {device}")
25
 
 
74
  image = ImageCaptcha(width=width, height=height)
75
  filepath = os.path.join(cfg.RESULT_DIR, filename)
76
  image.write(text, filepath)
77
+ print(f"Generated test CAPTCHA: {filename}")
78
  return filepath
79
 
80
  def main():
81
  # Setup
82
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
83
+ print(f"Using device: {device}")
84
 
85
  os.makedirs(cfg.RESULT_DIR, exist_ok=True)
86
 
87
  try:
88
  # Load trained model
89
+ print("Loading trained model...")
90
  model = load_model()
91
  model = model.to(device)
92
+ print("Model loaded successfully!")
93
 
94
  # Generate test CAPTCHAs
95
+ print("\nGenerating test CAPTCHAs...")
96
  test_cases = []
97
 
98
  for i in range(4):
 
105
  test_cases.append((text, image_path, "")) # Add empty prediction slot
106
 
107
  # Run inference
108
+ print("\nRunning inference...")
109
  print("-" * 60)
110
  print(f"{'Target':<15} {'Prediction':<15} {'Correct':<10} {'Image':<20}")
111
  print("-" * 60)
 
128
  correct_count += 1
129
 
130
  # Display result
131
+ status = "CORRECT" if is_correct else "WRONG"
132
  print(f"{target_text:<15} {prediction:<15} {status:<10} {os.path.basename(image_path):<20}")
133
 
134
  except Exception as e:
135
+ print(f"Error processing {image_path}: {e}")
136
 
137
  # Summary
138
  print("-" * 60)
139
  accuracy = (correct_count / len(test_cases)) * 100
140
+ print(f"Overall Accuracy: {correct_count}/{len(test_cases)} ({accuracy:.1f}%)")
141
 
142
  # Calculate individual character accuracy
143
  total_chars = 0
 
154
  print(f"πŸ”€ Character Accuracy: {correct_chars}/{total_chars} ({char_accuracy:.1f}%)")
155
 
156
  if accuracy >= 80:
157
+ print("Excellent performance!")
158
  elif accuracy >= 60:
159
+ print("Good performance!")
160
  else:
161
+ print("Room for improvement...")
162
 
163
  # Create and save results plot
164
+ print("\nGenerating results visualization...")
165
  try:
166
  metrics = TrainingMetrics()
167
  image_paths = [case[1] for case in test_cases]
 
173
 
174
  # Plot results
175
  metrics.plot_results(image_paths, predictions, targets)
176
+ print("Results plot generated successfully!")
177
 
178
  except Exception as e:
179
+ print(f"Warning: Could not generate plot: {e}")
180
 
181
  except Exception as e:
182
+ print(f"Error: {e}")
183
+ print("Make sure you have a trained model in checkpoints/best_model.pth")
184
 
185
  if __name__ == "__main__":
186
  main()
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CAPTCHA OCR Project Dependencies
2
+ # Core ML Framework (install separately with CUDA 12.8 support)
3
+ # pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu128
4
+
5
+ # Computer Vision
6
+ opencv-python>=4.8.0
7
+ Pillow>=9.0.0
8
+
9
+ # CAPTCHA Generation
10
+ captcha>=0.4.0
11
+
12
+ # Data Processing
13
+ numpy>=1.21.0
14
+ pandas>=1.3.0
15
+
16
+ # Visualization and Plotting
17
+ matplotlib>=3.5.0
train.py CHANGED
@@ -147,8 +147,8 @@ def main():
147
  if avg_val_loss < best_val_loss:
148
  best_val_loss = avg_val_loss
149
  patience_counter = 0
150
- print(f" 🎯 New best validation loss: {best_val_loss:.4f}")
151
- print(f" πŸ“Š Val/Train ratio: {val_train_ratio:.3f}")
152
 
153
  # Save best model checkpoint with metadata
154
  checkpoint = {
@@ -169,19 +169,19 @@ def main():
169
  }
170
  }
171
  torch.save(checkpoint, "checkpoints/best_model.pth")
172
- print(f" πŸ’Ύ Best model saved to checkpoints/best_model.pth")
173
 
174
  else:
175
  patience_counter += 1
176
- print(f" ⚠️ No improvement for {patience_counter} epochs")
177
- print(f" πŸ“Š Val/Train ratio: {val_train_ratio:.3f}")
178
 
179
  # Enhanced early stopping: Check both absolute loss and ratio
180
  if patience_counter >= patience or val_train_ratio > 3.0: # Stop if ratio > 3x
181
  if val_train_ratio > 3.0:
182
- print(f" πŸ›‘ Early stopping triggered! Val/Train ratio too high: {val_train_ratio:.3f}")
183
  else:
184
- print(f" πŸ›‘ Early stopping triggered! No improvement for {patience} epochs")
185
  early_stop = True
186
  break
187
 
@@ -253,7 +253,7 @@ def main():
253
  }
254
  }
255
  torch.save(final_checkpoint, "checkpoints/final_model.pth")
256
- print(f"πŸ’Ύ Final model saved to checkpoints/final_model.pth")
257
 
258
  print("\nGenerating training metrics and plots...")
259
  os.makedirs("Metrics", exist_ok=True)
 
147
  if avg_val_loss < best_val_loss:
148
  best_val_loss = avg_val_loss
149
  patience_counter = 0
150
+ print(f" New best validation loss: {best_val_loss:.4f}")
151
+ print(f" Val/Train ratio: {val_train_ratio:.3f}")
152
 
153
  # Save best model checkpoint with metadata
154
  checkpoint = {
 
169
  }
170
  }
171
  torch.save(checkpoint, "checkpoints/best_model.pth")
172
+ print(f" Best model saved to checkpoints/best_model.pth")
173
 
174
  else:
175
  patience_counter += 1
176
+ print(f" No improvement for {patience_counter} epochs")
177
+ print(f" Val/Train ratio: {val_train_ratio:.3f}")
178
 
179
  # Enhanced early stopping: Check both absolute loss and ratio
180
  if patience_counter >= patience or val_train_ratio > 3.0: # Stop if ratio > 3x
181
  if val_train_ratio > 3.0:
182
+ print(f" Early stopping triggered! Val/Train ratio too high: {val_train_ratio:.3f}")
183
  else:
184
+ print(f" Early stopping triggered! No improvement for {patience} epochs")
185
  early_stop = True
186
  break
187
 
 
253
  }
254
  }
255
  torch.save(final_checkpoint, "checkpoints/final_model.pth")
256
+ print(f"Final model saved to checkpoints/final_model.pth")
257
 
258
  print("\nGenerating training metrics and plots...")
259
  os.makedirs("Metrics", exist_ok=True)