| import torch
|
| import torchvision.transforms as transforms
|
| import torchvision.models.detection as detection
|
| from PIL import Image
|
| import cv2
|
| import numpy as np
|
|
|
|
|
# Label id -> display name used when annotating detections.
# NOTE(review): torchvision detection models conventionally reserve label 0
# for background; here 0 is mapped to "vehicle" while NUM_CLASSES adds one
# extra slot. Confirm these keys against the label map used for training.
CLASS_NAMES = dict(enumerate(("vehicle", "bicycle", "bus", "car", "lorry")))

# Size of the classification head passed to load_model(): one more than the
# number of named classes (presumably the background slot - TODO confirm).
NUM_CLASSES = 1 + len(CLASS_NAMES)


# Module-level flag: set to True by stop_real_time_detection() and polled by
# real_time_detection() to leave its capture loop.
stop_detection = False
|
|
|
|
|
def load_model(model_path, num_classes):
    """Build a Faster R-CNN (ResNet-50 FPN) detector and restore its weights.

    Args:
        model_path: Path of the checkpoint file. Its content is passed
            directly to ``load_state_dict``, so it must be a bare state dict.
        num_classes: Size of the classification head, forwarded unchanged to
            the torchvision constructor.

    Returns:
        The detector with the checkpoint loaded, switched to evaluation mode.
    """
    detector = detection.fasterrcnn_resnet50_fpn(
        weights=None,
        num_classes=num_classes,
    )

    # map_location="cpu" places every loaded tensor on the CPU.
    # NOTE(review): torch.load unpickles the file - only load checkpoints
    # from a trusted source.
    state_dict = torch.load(model_path, map_location="cpu")
    detector.load_state_dict(state_dict)

    # Module.eval() returns the module itself.
    return detector.eval()
|
|
|
|
|
def preprocess_image(image):
    """Turn an image into a one-item batch tensor for the detector.

    Args:
        image: Any input accepted by ``transforms.ToTensor`` (the callers in
            this file pass an RGB PIL image).

    Returns:
        The converted tensor with a new leading dimension of size 1.
    """
    to_tensor = transforms.ToTensor()
    tensor = to_tensor(image)
    # Add the batch axis in front.
    return tensor.unsqueeze(0)
|
|
|
|
|
def run_inference_on_image(image_path, model):
    """Run the detector on one image file and display the annotated result.

    Args:
        image_path: Path of the image to analyse.
        model: Detection model; it is called with a one-item batch tensor and
            the first element of its output is used as the prediction.

    Returns:
        None. If the image cannot be opened, an error is printed and the
        function returns without running inference.
    """
    try:
        # Image.open is lazy and keeps the underlying file open; the context
        # manager closes it once convert() has produced an independent copy.
        with Image.open(image_path) as source:
            rgb_image = source.convert("RGB")
    except OSError as err:
        # Missing file, permission problem or a file PIL cannot identify:
        # report it the same way the webcam path reports failures instead of
        # letting a mistyped path abort the interactive menu.
        print(f"❌ Error: Could not open image '{image_path}': {err}")
        return

    batch = preprocess_image(rgb_image)

    # Inference only - no gradients needed.
    with torch.no_grad():
        prediction = model(batch)[0]

    draw_predictions(image_path, prediction)
|
|
|
|
|
def draw_predictions(image_path, prediction):
    """Show the image at ``image_path`` with the detections drawn on it.

    Blocks until a key is pressed in the result window, then closes all
    OpenCV windows.

    Args:
        image_path: Path of the image file to load and annotate.
        prediction: Mapping with 'boxes', 'scores' and 'labels' entries, as
            consumed by ``draw_predictions_live``.

    Returns:
        None. If OpenCV cannot read the file, an error is printed and no
        window is opened.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising;
        # drawing on None would crash further down.
        print(f"❌ Error: Could not read image '{image_path}'.")
        return

    # Reuse the live-frame drawing routine so both paths annotate identically.
    draw_predictions_live(image, prediction)

    cv2.imshow("Detection Result", image)
    cv2.waitKey(0)  # wait indefinitely for a key press
    cv2.destroyAllWindows()
|
|
|
|
|
def stop_real_time_detection(event, x, y, flags, param):
    """Mouse callback that requests the live loop to stop on a left click.

    Only ``event`` is inspected; the remaining arguments are accepted because
    the function is registered through ``cv2.setMouseCallback``.
    """
    global stop_detection

    if event != cv2.EVENT_LBUTTONDOWN:
        return

    stop_detection = True
|
|
|
|
|
def real_time_detection(model):
    """Run the detector on webcam frames until the user stops it.

    Frames are read from capture device 0, annotated in place and shown in a
    window. The loop ends when a frame cannot be read, when 'q' is pressed,
    or when the window is left-clicked (via ``stop_real_time_detection``).

    Args:
        model: Detection model; it is called with a one-item batch tensor and
            the first element of its output is used as the prediction.

    Returns:
        None.
    """
    global stop_detection
    # Reset the flag so a click from a previous session does not end this one.
    stop_detection = False

    window_name = "Real-time Detection"

    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("❌ Error: Could not open webcam.")
        return

    print("🎥 Starting real-time object detection. Click the window to close.")

    try:
        cv2.namedWindow(window_name)
        cv2.setMouseCallback(window_name, stop_real_time_detection)

        while True:
            ret, frame = cap.read()
            if not ret:
                print("❌ Error: Failed to capture frame.")
                break

            # OpenCV delivers BGR; convert to RGB before building the PIL image.
            img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            img_tensor = preprocess_image(img)

            with torch.no_grad():
                prediction = model(img_tensor)[0]

            draw_predictions_live(frame, prediction)
            cv2.imshow(window_name, frame)

            # waitKey also pumps the GUI event loop, which delivers mouse events.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

            if stop_detection:
                print("🛑 Stopping real-time detection...")
                break
    finally:
        # Always free the camera and close the windows, even when inference
        # raises or the user interrupts with Ctrl-C.
        cap.release()
        cv2.destroyAllWindows()
        # Presumably gives the GUI backend a chance to process the teardown.
        cv2.waitKey(1)
|
|
|
|
|
def draw_predictions_live(frame, prediction, score_threshold=0.5):
    """Draw boxes and labels for confident detections onto ``frame`` in place.

    Args:
        frame: Image array that OpenCV draws on; it is modified in place.
        prediction: Mapping with 'boxes', 'scores' and 'labels' entries, each
            supporting ``tolist()``. Each box is unpacked as (x1, y1, x2, y2).
        score_threshold: Detections are drawn only when their score is
            strictly greater than this value. Defaults to 0.5.

    Returns:
        None.
    """
    # Convert once instead of indexing the tensors element by element.
    boxes = prediction['boxes'].tolist()
    scores = prediction['scores'].tolist()
    labels = prediction['labels'].tolist()

    colour = (0, 255, 0)
    # Smallest baseline that keeps the label text inside the frame.
    min_label_y = 15

    for box, score, class_id in zip(boxes, scores, labels):
        if score <= score_threshold:
            continue

        x1, y1, x2, y2 = map(int, box)
        class_name = CLASS_NAMES.get(class_id, f"Class {class_id}")

        cv2.rectangle(frame, (x1, y1), (x2, y2), colour, 2)

        # putText's origin is the bottom-left corner of the text, so a box
        # touching the top edge would push the label off-frame; clamp it.
        label_y = max(y1 - 10, min_label_y)
        cv2.putText(frame, f"{class_name}: {score:.2f}", (x1, label_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, colour, 2)
|
|
|
if __name__ == "__main__":
    # Interactive entry point: load the detector once, then serve menu choices
    # until the user quits.
    model = load_model("models/fasterrcnn_model.pth", NUM_CLASSES)

    menu_lines = (
        "\nOptions:",
        "1 - Run object detection on an image",
        "2 - Run real-time object detection (Click to exit)",
        "q - Quit",
    )

    while True:
        for line in menu_lines:
            print(line)

        # Normalise the answer so surrounding spaces and letter case don't matter.
        choice = input("Enter your choice: ").strip().lower()

        if choice == "q":
            print("👋 Exiting program.")
            break

        if choice == "1":
            image_path = input("Enter the path of the test image: ").strip()
            run_inference_on_image(image_path, model)
        elif choice == "2":
            real_time_detection(model)
        else:
            print("⚠️ Invalid choice! Please enter '1', '2', or 'q'.")
|
|
|