Spaces:
Running
Running
| from flask import Flask, request, jsonify,send_from_directory | |
| from flask_cors import CORS | |
| import tensorflow as tf | |
| import numpy as np | |
| from PIL import Image | |
| import io | |
| import base64 | |
# When True, preprocess_image() is called with save_debug=True and writes every
# intermediate stage of the pipeline to ./debug/*.png.
DEBUG_MODE = False

app = Flask(__name__)
CORS(app)  # allow cross-origin requests to this app (flask_cors defaults)

# Load the classifier once, at import time. A failure is reported but not fatal:
# the server still starts with model = None, and request handlers check for that.
print("Loading model...")
model = None
try:
    model = tf.keras.models.load_model('mnist_model.keras')
except Exception as load_error:
    print(f"Error loading model: {load_error}")
else:
    print("Model loaded successfully!")
def _ink_bounding_box(image_array, threshold):
    """Return the inclusive (ymin, ymax, xmin, xmax) box of pixels brighter than *threshold*.

    Returns None when no pixel exceeds the threshold (a blank canvas).
    """
    mask = image_array > threshold
    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)
    if not rows.any() or not cols.any():
        return None
    ymin, ymax = np.where(rows)[0][[0, -1]]
    xmin, xmax = np.where(cols)[0][[0, -1]]
    return int(ymin), int(ymax), int(xmin), int(xmax)


def _crop_with_margin(image_array, box, margin):
    """Crop *image_array* to *box* (inclusive bounds) plus a proportional margin, clamped to the image."""
    ymin, ymax, xmin, xmax = box
    # Bounds are inclusive, so the stroke spans (max - min + 1) pixels.
    height = ymax - ymin + 1
    width = xmax - xmin + 1
    pad_y = int(height * margin)
    pad_x = int(width * margin)
    top = max(0, ymin - pad_y)
    left = max(0, xmin - pad_x)
    # +1 turns the inclusive max index into an exclusive slice end. Without it the
    # last row/column of ink is cut off, and a 1-pixel stroke crops to nothing.
    bottom = min(image_array.shape[0], ymax + 1 + pad_y)
    right = min(image_array.shape[1], xmax + 1 + pad_x)
    return image_array[top:bottom, left:right]


def _pad_to_square(cropped):
    """Center the 2-D *cropped* array on a zero (black) square whose side is its larger dimension."""
    h, w = cropped.shape[:2]
    side = max(h, w)
    square = np.zeros((side, side), dtype=np.uint8)
    y_off = (side - h) // 2
    x_off = (side - w) // 2
    square[y_off:y_off + h, x_off:x_off + w] = cropped
    return square


def preprocess_image(image_data, save_debug=False, *, threshold=50, margin=0.2):
    """Decode a base64 canvas image and turn it into a (1, 28, 28, 1) model input.

    Pipeline: decode -> grayscale -> crop to pixels brighter than *threshold*
    (plus a margin) -> center on a black square -> resize to 28x28 with
    LANCZOS -> scale to [0, 1]. Treating bright pixels as ink and padding with
    zeros assumes light strokes on a dark background (confirm against the
    front-end canvas).

    Args:
        image_data: Base64 image data; if it contains a comma, a data-URL
            header such as "data:image/png;base64," is stripped first.
        save_debug: When True, write each intermediate stage to ./debug/*.png.
        threshold: Grayscale value (0-255) above which a pixel counts as ink.
        margin: Fraction of the stroke's height/width added as padding on
            each side before squaring.

    Returns:
        float32 ndarray of shape (1, 28, 28, 1) with values in [0, 1]; all
        zeros for a blank canvas. Returns None if any step fails (the error
        and traceback are printed).
    """
    try:
        if ',' in image_data:
            image_data = image_data.split(',')[1]
        image_bytes = base64.b64decode(image_data)
        image = Image.open(io.BytesIO(image_bytes)).convert('L')
        image_array = np.array(image)
        if save_debug:
            import os
            os.makedirs('debug', exist_ok=True)
            Image.fromarray(image_array).save('debug/1_original.png')
            print("Saved: debug/1_original.png")

        box = _ink_bounding_box(image_array, threshold)
        if box is None:
            print("β Empty canvas detected")
            return np.zeros((1, 28, 28, 1), dtype=np.float32)

        cropped = _crop_with_margin(image_array, box, margin)
        if save_debug:
            Image.fromarray(cropped).save('debug/2_cropped.png')
            print("Saved: debug/2_cropped.png")

        square = _pad_to_square(cropped)
        if save_debug:
            Image.fromarray(square).save('debug/3_squared.png')
            print("β Saved: debug/3_squared.png")

        resized = Image.fromarray(square).resize((28, 28), Image.Resampling.LANCZOS)
        if save_debug:
            resized.save('debug/4_resized_28x28.png')
            print("Saved: debug/4_resized_28x28.png")

        normalized = np.array(resized).astype('float32') / 255.0
        if save_debug:
            final_img = (normalized * 255).astype(np.uint8)
            Image.fromarray(final_img).save('debug/5_final_normalized.png')
            print("β Saved: debug/5_final_normalized.png")
            print(f" Min: {normalized.min():.3f}, Max: {normalized.max():.3f}, Mean: {normalized.mean():.3f}")

        return normalized.reshape(1, 28, 28, 1)
    except Exception as e:
        # Best-effort: the /predict handler maps None to a 400 response.
        print(f"β Preprocessing error: {e}")
        import traceback
        traceback.print_exc()
        return None
@app.route('/')
def index():
    """Serve the front-end page ``index.html``.

    Flask resolves the relative directory '.' against the application's root path.
    """
    return send_from_directory('.', 'index.html')
@app.route('/health')
def health():
    """Report liveness and whether the model was loaded at startup."""
    return jsonify({
        'status': 'healthy',
        'model_loaded': model is not None,
    })
@app.route('/predict', methods=['POST', 'OPTIONS'])
def predict():
    """Classify a drawn digit sent as JSON ``{"image": "<base64 or data URL>"}``.

    Responses:
        200: ``predicted_digit``, ``confidence`` and ``probabilities`` (one
             entry per model output, keyed by class index as a string).
        400: body is not a JSON object with an ``image`` key, or the image
             could not be preprocessed.
        500: model not loaded, or inference raised.
    OPTIONS requests (CORS preflight) are answered directly.
    """
    if request.method == 'OPTIONS':
        response = jsonify({'status': 'ok'})
        response.headers.add('Access-Control-Allow-Origin', '*')
        response.headers.add('Access-Control-Allow-Headers', 'Content-Type')
        response.headers.add('Access-Control-Allow-Methods', 'POST')
        return response
    if model is None:
        return jsonify({'error': 'Model not loaded'}), 500

    # silent=True returns None for a non-JSON or malformed body instead of raising
    # an HTTPException that the broad handler below would report as a 500.
    data = request.get_json(silent=True)
    # Require a JSON object: for a bare JSON string, `'image' in data` would be a
    # substring test and data['image'] would then raise.
    if not isinstance(data, dict) or 'image' not in data:
        return jsonify({'error': 'No image data'}), 400

    try:
        processed = preprocess_image(data['image'], save_debug=DEBUG_MODE)
        if processed is None:
            return jsonify({'error': 'Image processing failed'}), 400
        predictions = model.predict(processed, verbose=0)
        probs = predictions[0]
        predicted_digit = int(np.argmax(probs))
        confidence = float(probs[predicted_digit])
        result = {
            'predicted_digit': predicted_digit,
            'confidence': confidence,
            'probabilities': {str(i): float(p) for i, p in enumerate(probs)},
        }
        print(f"Prediction: {predicted_digit} (confidence: {confidence:.2%})")
        return jsonify(result)
    except Exception as e:
        print(f"ERROR: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500
if __name__ == "__main__":
    import os

    rule = "=" * 60
    print("\n" + rule)
    print("MNIST Digit Recognition API")
    print(rule)
    print(f"\nModel loaded: {model is not None}")

    # Port comes from $PORT, defaulting to 7860.
    port = int(os.environ.get("PORT", 7860))
    print(f"\nStarting server on port {port}")
    print("\nEndpoints:")
    # index() serves index.html, so describe "/" as the web UI rather than a status page.
    print(" GET / - Web UI (index.html)")
    print(" GET /health - Health check")
    print(" POST /predict - Predict digit")
    print("\n" + rule + "\n")

    # Bind all interfaces (not just localhost) so external clients can connect.
    app.run(host="0.0.0.0", port=port)