"""Drowsiness detection: image preprocessing, Keras inference and a thin API layer."""

import base64
import io
import logging
import os

import cv2
import numpy as np
import tensorflow as tf
from PIL import Image

logger = logging.getLogger(__name__)


class DrowsinessDetector:
    """Binary drowsiness classifier wrapping a Keras model.

    The model is fed a batch of RGB images of shape ``input_shape`` with pixel
    values divided by 255. Its first output value for the first image is read
    as the drowsiness probability.
    """

    def __init__(self, threshold=0.5):
        """Create a detector with no model loaded.

        Args:
            threshold: Probability strictly above which an image is reported
                as drowsy. Defaults to 0.5.
        """
        self.model = None
        # Keras convention: (height, width, channels).
        self.input_shape = (64, 64, 3)
        self.threshold = threshold

    def load_model(self, model_path):
        """Load a saved Keras model from ``model_path`` into ``self.model``."""
        self.model = tf.keras.models.load_model(model_path)

    @staticmethod
    def _decode(image):
        """Turn a base64 string, raw bytes or a PIL image into an RGB ndarray.

        Any other input (e.g. an ndarray) is returned unchanged.
        """
        if isinstance(image, str):
            # Accept data URIs ("data:image/png;base64,...") as well as bare
            # base64; lenient b64decode would otherwise decode the header
            # characters as image data.
            if image.startswith("data:") and "," in image:
                image = image.split(",", 1)[1]
            image = base64.b64decode(image)
        if isinstance(image, (bytes, bytearray)):
            image = Image.open(io.BytesIO(image))
        if isinstance(image, Image.Image):
            # convert("RGB") normalises palette (P), LA, CMYK, etc.; a plain
            # np.array() would yield palette indices or odd channel counts.
            image = np.array(image.convert("RGB"))
        return image

    def preprocess_image(self, image):
        """Convert ``image`` into a model-ready batch containing one image.

        Args:
            image: A base64-encoded string (optionally a ``data:`` URI), raw
                encoded image bytes, a PIL image, or an ndarray laid out as
                (H, W), (H, W, 1), (H, W, 3) or (H, W, 4). Arrays with 3 or 4
                channels are assumed to already be RGB/RGBA ordered.

        Returns:
            float32 ndarray of shape (1, height, width, 3), pixel values
            divided by 255.

        Raises:
            ValueError: If the decoded array has an unsupported layout.
        """
        image = np.asarray(self._decode(image))

        # Normalise the channel layout to 3-channel RGB.
        if image.ndim == 3 and image.shape[2] == 1:
            # cv2.resize would silently drop a singleton channel axis.
            image = np.squeeze(image, axis=2)
        if image.ndim == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.ndim == 3 and image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
        if image.ndim != 3 or image.shape[2] != 3:
            raise ValueError(
                f"Unsupported image shape {image.shape}; expected (H, W), "
                "(H, W, 1), (H, W, 3) or (H, W, 4)"
            )

        height, width = self.input_shape[:2]
        # cv2.resize takes dsize as (width, height), the reverse of the
        # (height, width) order used by input_shape.
        image = cv2.resize(image, (width, height))
        # NOTE: dividing by 255 assumes 8-bit pixel values.
        image = image.astype(np.float32) / 255.0
        return np.expand_dims(image, axis=0)

    def predict(self, image):
        """Classify a single image.

        Args:
            image: Any input accepted by ``preprocess_image``.

        Returns:
            dict with ``"drowsy_probability"`` (float) and ``"is_drowsy"``
            (bool, probability > ``self.threshold``).

        Raises:
            ValueError: If no model has been loaded, or the image layout is
                unsupported.
        """
        if self.model is None:
            raise ValueError("Model not loaded. Call load_model() first.")

        processed_image = self.preprocess_image(image)
        # verbose=0 stops Keras printing a progress bar on every API call.
        prediction = self.model.predict(processed_image, verbose=0)
        # Assumes a single sigmoid-style output unit -- TODO confirm against model.
        probability = float(prediction[0][0])
        return {
            "drowsy_probability": probability,
            "is_drowsy": probability > self.threshold,
        }


# Shared instance used by the module-level API functions below.
detector = DrowsinessDetector()


def load_model(model_path="model_weights.h5"):
    """Load the model into the shared detector (call once when the API starts)."""
    detector.load_model(model_path)


def predict(image):
    """API endpoint: classify ``image`` and wrap the outcome in a status dict.

    Returns:
        ``{"status": "success", "prediction": {...}}`` on success.
        If an ``Exception`` is raised, it is logged and reported as
        ``{"status": "error", "message": str(exc)}`` instead of being raised.
    """
    try:
        result = detector.predict(image)
    except Exception as e:  # API boundary: report the error instead of crashing
        logger.exception("Drowsiness prediction failed")
        return {"status": "error", "message": str(e)}
    return {"status": "success", "prediction": result}


# For local testing
if __name__ == "__main__":
    load_model()

    test_image_path = "test_image.jpg"  # Replace with your test image
    if os.path.exists(test_image_path):
        with open(test_image_path, "rb") as f:
            image_data = f.read()
        result = predict(image_data)
        print("Prediction result:", result)
    else:
        print(f"Test image not found: {test_image_path}")