Spaces:
Sleeping
Sleeping
Commit
·
858aaf1
1
Parent(s):
04fdb24
update inference.py
Browse files
- Metrics/inference_results.png +2 -2
- inference.py +18 -18
- requirements.txt +17 -0
- train.py +8 -8
Metrics/inference_results.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
inference.py
CHANGED
|
@@ -19,7 +19,7 @@ def load_model(checkpoint_path="checkpoints/best_model.pth"):
|
|
| 19 |
|
| 20 |
# Load checkpoint to the detected device
|
| 21 |
checkpoint = torch.load(checkpoint_path, map_location=device)
|
| 22 |
-
print(f"
|
| 23 |
print(f" Best validation loss: {checkpoint['best_val_loss']:.4f}")
|
| 24 |
print(f" Loading to device: {device}")
|
| 25 |
|
|
@@ -74,25 +74,25 @@ def generate_test_captcha(text, filename, width=160, height=60):
|
|
| 74 |
image = ImageCaptcha(width=width, height=height)
|
| 75 |
filepath = os.path.join(cfg.RESULT_DIR, filename)
|
| 76 |
image.write(text, filepath)
|
| 77 |
-
print(f"
|
| 78 |
return filepath
|
| 79 |
|
| 80 |
def main():
|
| 81 |
# Setup
|
| 82 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 83 |
-
print(f"
|
| 84 |
|
| 85 |
os.makedirs(cfg.RESULT_DIR, exist_ok=True)
|
| 86 |
|
| 87 |
try:
|
| 88 |
# Load trained model
|
| 89 |
-
print("
|
| 90 |
model = load_model()
|
| 91 |
model = model.to(device)
|
| 92 |
-
print("
|
| 93 |
|
| 94 |
# Generate test CAPTCHAs
|
| 95 |
-
print("\
|
| 96 |
test_cases = []
|
| 97 |
|
| 98 |
for i in range(4):
|
|
@@ -105,7 +105,7 @@ def main():
|
|
| 105 |
test_cases.append((text, image_path, "")) # Add empty prediction slot
|
| 106 |
|
| 107 |
# Run inference
|
| 108 |
-
print("\
|
| 109 |
print("-" * 60)
|
| 110 |
print(f"{'Target':<15} {'Prediction':<15} {'Correct':<10} {'Image':<20}")
|
| 111 |
print("-" * 60)
|
|
@@ -128,16 +128,16 @@ def main():
|
|
| 128 |
correct_count += 1
|
| 129 |
|
| 130 |
# Display result
|
| 131 |
-
status = "
|
| 132 |
print(f"{target_text:<15} {prediction:<15} {status:<10} {os.path.basename(image_path):<20}")
|
| 133 |
|
| 134 |
except Exception as e:
|
| 135 |
-
print(f"
|
| 136 |
|
| 137 |
# Summary
|
| 138 |
print("-" * 60)
|
| 139 |
accuracy = (correct_count / len(test_cases)) * 100
|
| 140 |
-
print(f"
|
| 141 |
|
| 142 |
# Calculate individual character accuracy
|
| 143 |
total_chars = 0
|
|
@@ -154,14 +154,14 @@ def main():
|
|
| 154 |
print(f"π€ Character Accuracy: {correct_chars}/{total_chars} ({char_accuracy:.1f}%)")
|
| 155 |
|
| 156 |
if accuracy >= 80:
|
| 157 |
-
print("
|
| 158 |
elif accuracy >= 60:
|
| 159 |
-
print("
|
| 160 |
else:
|
| 161 |
-
print("
|
| 162 |
|
| 163 |
# Create and save results plot
|
| 164 |
-
print("\
|
| 165 |
try:
|
| 166 |
metrics = TrainingMetrics()
|
| 167 |
image_paths = [case[1] for case in test_cases]
|
|
@@ -173,14 +173,14 @@ def main():
|
|
| 173 |
|
| 174 |
# Plot results
|
| 175 |
metrics.plot_results(image_paths, predictions, targets)
|
| 176 |
-
print("
|
| 177 |
|
| 178 |
except Exception as e:
|
| 179 |
-
print(f"
|
| 180 |
|
| 181 |
except Exception as e:
|
| 182 |
-
print(f"
|
| 183 |
-
print("
|
| 184 |
|
| 185 |
if __name__ == "__main__":
|
| 186 |
main()
|
|
|
|
| 19 |
|
| 20 |
# Load checkpoint to the detected device
|
| 21 |
checkpoint = torch.load(checkpoint_path, map_location=device)
|
| 22 |
+
print(f"Loaded model from epoch {checkpoint['epoch']}")
|
| 23 |
print(f" Best validation loss: {checkpoint['best_val_loss']:.4f}")
|
| 24 |
print(f" Loading to device: {device}")
|
| 25 |
|
|
|
|
| 74 |
image = ImageCaptcha(width=width, height=height)
|
| 75 |
filepath = os.path.join(cfg.RESULT_DIR, filename)
|
| 76 |
image.write(text, filepath)
|
| 77 |
+
print(f"Generated test CAPTCHA: {filename}")
|
| 78 |
return filepath
|
| 79 |
|
| 80 |
def main():
|
| 81 |
# Setup
|
| 82 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 83 |
+
print(f"Using device: {device}")
|
| 84 |
|
| 85 |
os.makedirs(cfg.RESULT_DIR, exist_ok=True)
|
| 86 |
|
| 87 |
try:
|
| 88 |
# Load trained model
|
| 89 |
+
print("Loading trained model...")
|
| 90 |
model = load_model()
|
| 91 |
model = model.to(device)
|
| 92 |
+
print("Model loaded successfully!")
|
| 93 |
|
| 94 |
# Generate test CAPTCHAs
|
| 95 |
+
print("\nGenerating test CAPTCHAs...")
|
| 96 |
test_cases = []
|
| 97 |
|
| 98 |
for i in range(4):
|
|
|
|
| 105 |
test_cases.append((text, image_path, "")) # Add empty prediction slot
|
| 106 |
|
| 107 |
# Run inference
|
| 108 |
+
print("\nRunning inference...")
|
| 109 |
print("-" * 60)
|
| 110 |
print(f"{'Target':<15} {'Prediction':<15} {'Correct':<10} {'Image':<20}")
|
| 111 |
print("-" * 60)
|
|
|
|
| 128 |
correct_count += 1
|
| 129 |
|
| 130 |
# Display result
|
| 131 |
+
status = "CORRECT" if is_correct else "WRONG"
|
| 132 |
print(f"{target_text:<15} {prediction:<15} {status:<10} {os.path.basename(image_path):<20}")
|
| 133 |
|
| 134 |
except Exception as e:
|
| 135 |
+
print(f"Error processing {image_path}: {e}")
|
| 136 |
|
| 137 |
# Summary
|
| 138 |
print("-" * 60)
|
| 139 |
accuracy = (correct_count / len(test_cases)) * 100
|
| 140 |
+
print(f"Overall Accuracy: {correct_count}/{len(test_cases)} ({accuracy:.1f}%)")
|
| 141 |
|
| 142 |
# Calculate individual character accuracy
|
| 143 |
total_chars = 0
|
|
|
|
| 154 |
print(f"π€ Character Accuracy: {correct_chars}/{total_chars} ({char_accuracy:.1f}%)")
|
| 155 |
|
| 156 |
if accuracy >= 80:
|
| 157 |
+
print("Excellent performance!")
|
| 158 |
elif accuracy >= 60:
|
| 159 |
+
print("Good performance!")
|
| 160 |
else:
|
| 161 |
+
print("Room for improvement...")
|
| 162 |
|
| 163 |
# Create and save results plot
|
| 164 |
+
print("\nGenerating results visualization...")
|
| 165 |
try:
|
| 166 |
metrics = TrainingMetrics()
|
| 167 |
image_paths = [case[1] for case in test_cases]
|
|
|
|
| 173 |
|
| 174 |
# Plot results
|
| 175 |
metrics.plot_results(image_paths, predictions, targets)
|
| 176 |
+
print("Results plot generated successfully!")
|
| 177 |
|
| 178 |
except Exception as e:
|
| 179 |
+
print(f"Warning: Could not generate plot: {e}")
|
| 180 |
|
| 181 |
except Exception as e:
|
| 182 |
+
print(f"Error: {e}")
|
| 183 |
+
print("Make sure you have a trained model in checkpoints/best_model.pth")
|
| 184 |
|
| 185 |
if __name__ == "__main__":
|
| 186 |
main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CAPTCHA OCR Project Dependencies
|
| 2 |
+
# Core ML Framework (install separately with CUDA 12.8 support)
|
| 3 |
+
# pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu128
|
| 4 |
+
|
| 5 |
+
# Computer Vision
|
| 6 |
+
opencv-python>=4.8.0
|
| 7 |
+
Pillow>=9.0.0
|
| 8 |
+
|
| 9 |
+
# CAPTCHA Generation
|
| 10 |
+
captcha>=0.4.0
|
| 11 |
+
|
| 12 |
+
# Data Processing
|
| 13 |
+
numpy>=1.21.0
|
| 14 |
+
pandas>=1.3.0
|
| 15 |
+
|
| 16 |
+
# Visualization and Plotting
|
| 17 |
+
matplotlib>=3.5.0
|
train.py
CHANGED
|
@@ -147,8 +147,8 @@ def main():
|
|
| 147 |
if avg_val_loss < best_val_loss:
|
| 148 |
best_val_loss = avg_val_loss
|
| 149 |
patience_counter = 0
|
| 150 |
-
print(f"
|
| 151 |
-
print(f"
|
| 152 |
|
| 153 |
# Save best model checkpoint with metadata
|
| 154 |
checkpoint = {
|
|
@@ -169,19 +169,19 @@ def main():
|
|
| 169 |
}
|
| 170 |
}
|
| 171 |
torch.save(checkpoint, "checkpoints/best_model.pth")
|
| 172 |
-
print(f"
|
| 173 |
|
| 174 |
else:
|
| 175 |
patience_counter += 1
|
| 176 |
-
print(f"
|
| 177 |
-
print(f"
|
| 178 |
|
| 179 |
# Enhanced early stopping: Check both absolute loss and ratio
|
| 180 |
if patience_counter >= patience or val_train_ratio > 3.0: # Stop if ratio > 3x
|
| 181 |
if val_train_ratio > 3.0:
|
| 182 |
-
print(f"
|
| 183 |
else:
|
| 184 |
-
print(f"
|
| 185 |
early_stop = True
|
| 186 |
break
|
| 187 |
|
|
@@ -253,7 +253,7 @@ def main():
|
|
| 253 |
}
|
| 254 |
}
|
| 255 |
torch.save(final_checkpoint, "checkpoints/final_model.pth")
|
| 256 |
-
print(f"
|
| 257 |
|
| 258 |
print("\nGenerating training metrics and plots...")
|
| 259 |
os.makedirs("Metrics", exist_ok=True)
|
|
|
|
| 147 |
if avg_val_loss < best_val_loss:
|
| 148 |
best_val_loss = avg_val_loss
|
| 149 |
patience_counter = 0
|
| 150 |
+
print(f" New best validation loss: {best_val_loss:.4f}")
|
| 151 |
+
print(f" Val/Train ratio: {val_train_ratio:.3f}")
|
| 152 |
|
| 153 |
# Save best model checkpoint with metadata
|
| 154 |
checkpoint = {
|
|
|
|
| 169 |
}
|
| 170 |
}
|
| 171 |
torch.save(checkpoint, "checkpoints/best_model.pth")
|
| 172 |
+
print(f" Best model saved to checkpoints/best_model.pth")
|
| 173 |
|
| 174 |
else:
|
| 175 |
patience_counter += 1
|
| 176 |
+
print(f" No improvement for {patience_counter} epochs")
|
| 177 |
+
print(f" Val/Train ratio: {val_train_ratio:.3f}")
|
| 178 |
|
| 179 |
# Enhanced early stopping: Check both absolute loss and ratio
|
| 180 |
if patience_counter >= patience or val_train_ratio > 3.0: # Stop if ratio > 3x
|
| 181 |
if val_train_ratio > 3.0:
|
| 182 |
+
print(f" Early stopping triggered! Val/Train ratio too high: {val_train_ratio:.3f}")
|
| 183 |
else:
|
| 184 |
+
print(f" Early stopping triggered! No improvement for {patience} epochs")
|
| 185 |
early_stop = True
|
| 186 |
break
|
| 187 |
|
|
|
|
| 253 |
}
|
| 254 |
}
|
| 255 |
torch.save(final_checkpoint, "checkpoints/final_model.pth")
|
| 256 |
+
print(f"Final model saved to checkpoints/final_model.pth")
|
| 257 |
|
| 258 |
print("\nGenerating training metrics and plots...")
|
| 259 |
os.makedirs("Metrics", exist_ok=True)
|