| | |
| | """ |
| | Real-time strawberry detection/classification using TFLite model. |
| | Supports binary classification (good/bad); the YOLOv8 detection mode is only a placeholder and is not yet implemented. |
| | """ |
| |
|
| | import argparse |
| | import cv2 |
| | import numpy as np |
| | import tensorflow as tf |
| | from pathlib import Path |
| | import sys |
| |
|
def load_tflite_model(model_path):
    """Load a TFLite model from disk and allocate its tensors.

    Args:
        model_path: Location of the ``.tflite`` file, as a ``str`` or any
            path-like object accepted by ``pathlib.Path``.

    Returns:
        A ``tf.lite.Interpreter`` whose tensors have already been allocated,
        ready for ``set_tensor`` / ``invoke`` / ``get_tensor``.

    Raises:
        FileNotFoundError: If ``model_path`` is not an existing regular file
            (a directory is rejected too).
    """
    path = Path(model_path)
    # is_file() rather than exists(): a directory must not pass this guard.
    if not path.is_file():
        raise FileNotFoundError(f"Model file not found: {model_path}")

    # Hand the interpreter a plain string so path-like arguments also work.
    interpreter = tf.lite.Interpreter(model_path=str(path))
    interpreter.allocate_tensors()
    return interpreter
| |
|
def get_model_details(interpreter):
    """Return the interpreter's tensor metadata as ``(inputs, outputs)``.

    The two elements are exactly what ``interpreter.get_input_details()``
    and ``interpreter.get_output_details()`` return, queried in that order.
    """
    return interpreter.get_input_details(), interpreter.get_output_details()
| |
|
def preprocess_image(image, input_shape):
    """Resize and scale a frame into a one-image batch for the model.

    No colour-space conversion is performed; the pixels keep whatever
    channel order ``image`` already has.

    Args:
        image: Image array as produced by OpenCV (e.g. a camera frame).
        input_shape: The model's input tensor shape. It must have exactly
            four dimensions; dimensions 1 and 2 are read as height and width,
            i.e. a channels-last ``(batch, height, width, channels)`` layout
            is assumed.

    Returns:
        A ``float32`` array with a leading batch axis of size 1, spatially
        resized to ``(height, width)``, with pixel values divided by 255.

    Raises:
        ValueError: If ``input_shape`` does not have four dimensions.
    """
    if len(input_shape) != 4:
        raise ValueError(
            "Expected a 4-D (batch, height, width, channels) input shape, "
            f"got {tuple(int(d) for d in input_shape)}"
        )
    # Plain ints: the shape usually comes from a numpy array of int32.
    height, width = int(input_shape[1]), int(input_shape[2])
    # cv2.resize takes the destination size as (width, height).
    img = cv2.resize(image, (width, height))
    img = img / 255.0
    return np.expand_dims(img, axis=0).astype(np.float32)
| |
|
def run_inference(interpreter, input_details, output_details, preprocessed_img):
    """Push one preprocessed batch through the interpreter and read the result.

    Only the first entry of ``input_details`` and of ``output_details`` is
    used: the batch is written to the first input tensor, the graph is
    invoked, and the first output tensor is returned as given by
    ``interpreter.get_tensor``.
    """
    in_idx = input_details[0]['index']
    interpreter.set_tensor(in_idx, preprocessed_img)
    interpreter.invoke()
    out_idx = output_details[0]['index']
    return interpreter.get_tensor(out_idx)
| |
|
def _build_arg_parser():
    """Return the argparse parser for this script's command-line options."""
    parser = argparse.ArgumentParser(description='Real-time strawberry detection/classification')
    parser.add_argument('--model', type=str, default='strawberry_model.tflite',
                        help='Path to TFLite model (default: strawberry_model.tflite)')
    parser.add_argument('--camera', type=int, default=0,
                        help='Camera index (default: 0)')
    parser.add_argument('--threshold', type=float, default=0.5,
                        help='Confidence threshold for binary classification (default: 0.5)')
    # NOTE(review): --input-size is parsed but never read; resizing always
    # follows the model's own input shape. Kept so existing command lines work.
    parser.add_argument('--input-size', type=int, default=224,
                        help='Input image size (width=height) for model (default: 224)')
    parser.add_argument('--mode', choices=['classification', 'detection'], default='classification',
                        help='Inference mode: classification (good/bad) or detection (YOLO)')
    parser.add_argument('--verbose', action='store_true',
                        help='Print detailed inference information')
    return parser


def _classification_overlay(confidence, threshold):
    """Return ``(text, colour)`` to draw for a binary good/bad score.

    A score strictly above ``threshold`` is labelled 'Good' and drawn green;
    anything else is 'Bad' and drawn red (colours are OpenCV BGR tuples).
    The number shown is always the raw model score, also for 'Bad'.
    """
    is_good = confidence > threshold
    label = 'Good' if is_good else 'Bad'
    color = (0, 255, 0) if is_good else (0, 0, 255)
    return f'{label}: {confidence:.2f}', color


def main():
    """Run the live camera inference loop until 'q' is pressed.

    Loads the TFLite model and opens the camera given on the command line,
    then for every frame runs inference and overlays the result. In
    'classification' mode the overlay is a Good/Bad label; 'detection' mode
    is a placeholder that only shows a notice. Press 'q' to quit, or 's' to
    save the annotated frame as ``capture_<tick count>.jpg`` in the current
    directory. Exits with status 1 if the model or the camera cannot be
    opened.
    """
    args = _build_arg_parser().parse_args()

    # Top-level boundary: any failure while loading or inspecting the model
    # is reported and turned into exit status 1.
    try:
        interpreter = load_tflite_model(args.model)
        input_details, output_details = get_model_details(interpreter)
        input_shape = input_details[0]['shape']
        if args.verbose:
            print(f"Model loaded: {args.model}")
            print(f"Input shape: {input_shape}")
            print(f"Output details: {output_details[0]}")
    except Exception as e:
        print(f"Error loading model: {e}", file=sys.stderr)
        sys.exit(1)

    cap = cv2.VideoCapture(args.camera)
    if not cap.isOpened():
        print(f"Cannot open camera index {args.camera}", file=sys.stderr)
        sys.exit(1)

    print(f"Starting real-time inference (mode: {args.mode})")
    print("Press 'q' to quit, 's' to save current frame")

    # try/finally so the camera is released and windows are closed even when
    # inference, drawing, or a Ctrl+C aborts the loop.
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("Failed to capture frame", file=sys.stderr)
                break

            preprocessed = preprocess_image(frame, input_shape)
            predictions = run_inference(interpreter, input_details, output_details, preprocessed)

            if args.mode == 'classification':
                # Assumes the model emits a single good-vs-bad score at
                # predictions[0][0] (e.g. a sigmoid output) -- TODO confirm.
                display_text, color = _classification_overlay(predictions[0][0], args.threshold)
            else:
                # Placeholder: inference still runs, but its output is not decoded.
                display_text = 'Detection mode not yet implemented'
                color = (255, 255, 0)

            cv2.putText(frame, display_text, (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
            cv2.imshow('Strawberry Detection', frame)

            key = cv2.waitKey(1) & 0xFF
            if key == ord('q'):
                break
            if key == ord('s'):
                # The tick count changes between calls, so repeated saves get
                # distinct file names.
                filename = f'capture_{cv2.getTickCount()}.jpg'
                cv2.imwrite(filename, frame)
                print(f"Frame saved as {filename}")
    finally:
        cap.release()
        cv2.destroyAllWindows()

    print("Real-time detection stopped.")


if __name__ == '__main__':
    main()