DigitClassifier / app.py
vigilante2099's picture
Update app.py
d1d1898 verified
from flask import Flask, request, jsonify,send_from_directory
from flask_cors import CORS
import tensorflow as tf
import numpy as np
from PIL import Image
import io
import base64
# When True, /predict asks preprocess_image to dump every intermediate
# stage as a PNG under ./debug/.
DEBUG_MODE = False

# Flask application; CORS(app) registers flask-cors with its default settings.
app = Flask(__name__)
CORS(app)


def _load_model():
    """Load the Keras model from ./mnist_model.keras.

    Returns the loaded model, or None if loading raised (the error is printed
    and the server still starts; /health and /predict report the missing model).
    """
    print("Loading model...")
    try:
        loaded = tf.keras.models.load_model('mnist_model.keras')
        print("Model loaded successfully!")
        return loaded
    except Exception as exc:
        print(f"Error loading model: {exc}")
        return None


# Module-level handle read by the route handlers; None means loading failed.
model = _load_model()
def _crop_to_content(pixels, threshold=50, margin=0.2):
    """Crop a 2-D grayscale array to its "ink" bounding box plus a margin.

    Ink is any pixel strictly greater than ``threshold``. The box is widened on
    every side by ``margin`` times the box's height/width, clipped to the array.

    Returns:
        The cropped view of ``pixels``, or None when no pixel exceeds ``threshold``.
    """
    rows = np.any(pixels > threshold, axis=1)
    cols = np.any(pixels > threshold, axis=0)
    if not rows.any() or not cols.any():
        return None
    ymin, ymax = np.flatnonzero(rows)[[0, -1]]
    xmin, xmax = np.flatnonzero(cols)[[0, -1]]
    pad_y = int((ymax - ymin) * margin)
    pad_x = int((xmax - xmin) * margin)
    # ymax/xmax are *inclusive* indices of the last ink row/column, while slice
    # ends are exclusive: add 1 so that row/column is kept. Without it a stroke
    # only one pixel tall or wide would crop to an empty array.
    top = max(0, ymin - pad_y)
    bottom = min(pixels.shape[0], ymax + pad_y + 1)
    left = max(0, xmin - pad_x)
    right = min(pixels.shape[1], xmax + pad_x + 1)
    return pixels[top:bottom, left:right]


def _pad_to_square(pixels):
    """Center a 2-D uint8 array on a black (zero) square canvas of side max(h, w)."""
    height, width = pixels.shape
    side = max(height, width)
    square = np.zeros((side, side), dtype=np.uint8)
    y_offset = (side - height) // 2
    x_offset = (side - width) // 2
    square[y_offset:y_offset + height, x_offset:x_offset + width] = pixels
    return square


def preprocess_image(image_data, save_debug=False):
    """Convert a base64-encoded drawing into a model-ready 28x28 tensor.

    Pipeline: base64 decode -> grayscale -> crop to the ink bounding box with a
    20% margin -> center on a square black canvas -> LANCZOS resize to 28x28 ->
    scale to [0, 1]. Ink means pixels brighter than 50, so the drawing is
    expected to be light strokes on a dark background.
    NOTE(review): confirm the front-end canvas draws white on black.

    Args:
        image_data: Base64 image data as a string, optionally prefixed with a
            data-URL header such as ``"data:image/png;base64,"``.
        save_debug: If true, write each intermediate stage as a PNG under
            ``./debug/`` and print min/max/mean of the final array.

    Returns:
        A float32 array of shape (1, 28, 28, 1) with values in [0, 1]; all
        zeros when the canvas has no ink. Returns None if any step raises (the
        error and traceback are printed).
    """
    try:
        if ',' in image_data:
            # Drop the data-URL header; everything after the first comma is base64.
            image_data = image_data.split(',', 1)[1]
        image_bytes = base64.b64decode(image_data)
        with Image.open(io.BytesIO(image_bytes)) as image:
            image_array = np.array(image.convert('L'))
        if save_debug:
            import os
            os.makedirs('debug', exist_ok=True)
            Image.fromarray(image_array).save('debug/1_original.png')
            print("Saved: debug/1_original.png")

        cropped = _crop_to_content(image_array)
        if cropped is None:
            print("⚠ Empty canvas detected")
            return np.zeros((1, 28, 28, 1), dtype=np.float32)
        if save_debug:
            Image.fromarray(cropped).save('debug/2_cropped.png')
            print("Saved: debug/2_cropped.png")

        square = _pad_to_square(cropped)
        if save_debug:
            Image.fromarray(square).save('debug/3_squared.png')
            print("✓ Saved: debug/3_squared.png")

        resized = Image.fromarray(square).resize((28, 28), Image.Resampling.LANCZOS)
        if save_debug:
            resized.save('debug/4_resized_28x28.png')
            print("Saved: debug/4_resized_28x28.png")

        normalized = np.array(resized).astype('float32') / 255.0
        if save_debug:
            final_img = (normalized * 255).astype(np.uint8)
            Image.fromarray(final_img).save('debug/5_final_normalized.png')
            print("✓ Saved: debug/5_final_normalized.png")
            print(f" Min: {normalized.min():.3f}, Max: {normalized.max():.3f}, Mean: {normalized.mean():.3f}")
        # Add batch and channel axes: (1, 28, 28, 1).
        return normalized.reshape(1, 28, 28, 1)
    except Exception as e:
        # Best-effort boundary: the caller turns None into a 400 response.
        print(f"❌ Preprocessing error: {e}")
        import traceback
        traceback.print_exc()
        return None
@app.route('/')
def index():
    """Serve the front-end page, index.html, from the current working directory."""
    page = 'index.html'
    return send_from_directory('.', page)
@app.route('/health')
def health():
    """Report that the server is up and whether the Keras model was loaded."""
    payload = {'status': 'healthy', 'model_loaded': model is not None}
    return jsonify(payload)
@app.route('/predict', methods=['POST', 'OPTIONS'])
def predict():
    """Classify a hand-drawn digit sent as JSON.

    Request body: ``{"image": "<base64 image, optionally a data URL>"}``.
    Success response: ``{"predicted_digit": int, "confidence": float,
    "probabilities": {"0": float, "1": float, ...}}`` with one entry per
    model output.

    Status codes:
        200 -- prediction made, or an OPTIONS (CORS preflight) request.
        400 -- body is not a JSON object with an "image" field, or the image
               could not be preprocessed.
        500 -- model not loaded, or inference raised (message is returned).
    """
    if request.method == 'OPTIONS':
        # Explicit preflight reply allowing POST with a Content-Type header.
        response = jsonify({'status': 'ok'})
        response.headers.add('Access-Control-Allow-Origin', '*')
        response.headers.add('Access-Control-Allow-Headers', 'Content-Type')
        response.headers.add('Access-Control-Allow-Methods', 'POST')
        return response
    if model is None:
        return jsonify({'error': 'Model not loaded'}), 500

    # silent=True: malformed JSON or a non-JSON Content-Type yields None instead
    # of raising an HTTP exception that the catch-all below would report as 500.
    data = request.get_json(silent=True)
    # Also reject JSON arrays/strings, which would otherwise fail on data['image'].
    if not isinstance(data, dict) or 'image' not in data:
        return jsonify({'error': 'No image data'}), 400

    try:
        processed = preprocess_image(data['image'], save_debug=DEBUG_MODE)
        if processed is None:
            return jsonify({'error': 'Image processing failed'}), 400
        # Batch of one: take the score vector of the single input.
        probs = model.predict(processed, verbose=0)[0]
        predicted_digit = int(np.argmax(probs))
        confidence = float(probs[predicted_digit])
        result = {
            'predicted_digit': predicted_digit,
            'confidence': confidence,
            # Cast numpy floats to Python floats so jsonify can serialize them.
            'probabilities': {str(i): float(p) for i, p in enumerate(probs)},
        }
        print(f"Prediction: {predicted_digit} (confidence: {confidence:.2%})")
        return jsonify(result)
    except Exception as e:
        print(f"ERROR: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500
if __name__ == "__main__":
    # Script entry point: print a startup banner, then run the Flask server
    # bound to 0.0.0.0 (all IPv4 interfaces).
    import os
    print("\n" + "="*60)
    print("MNIST Digit Recognition API")
    print("="*60)
    print(f"\nModel loaded: {model is not None}")
    # The PORT environment variable overrides the default port 7860.
    port = int(os.environ.get("PORT", 7860))
    print(f"\nStarting server on port {port}")
    print("\nEndpoints:")
    # '/' serves index.html (see the index route), not a status payload.
    print(" GET / - Web UI (index.html)")
    print(" GET /health - Health check")
    print(" POST /predict - Predict digit")
    print("\n" + "="*60 + "\n")
    app.run(host="0.0.0.0", port=port)