ragebhanukiran's picture
initial commit
b338c44 verified
Raw
History Blame Contribute Delete
5.11 kB
import torch
import torchvision.transforms as transforms
import torchvision.models.detection as detection
from PIL import Image
import cv2
import numpy as np
# Human-readable names for the integer label ids reported by the detector.
# NOTE(review): torchvision detection models conventionally reserve label 0
# for background -- confirm these ids match the labels used at training time.
CLASS_NAMES = {
    0: "vehicle",
    1: "bicycle",
    2: "bus",
    3: "car",
    4: "lorry",
}

# Class count handed to the model constructor: named classes plus background.
NUM_CLASSES = 1 + len(CLASS_NAMES)

# Module-level flag; the mouse callback sets it to True to end the webcam loop.
stop_detection = False
# Load Faster R-CNN model
def load_model(model_path, num_classes):
    """Build a Faster R-CNN (ResNet-50 FPN) detector and load saved weights.

    Args:
        model_path: Path to the file holding the saved ``state_dict``.
        num_classes: Number of classes passed to the model constructor.

    Returns:
        The model with the weights loaded, switched to evaluation mode.
    """
    detector = detection.fasterrcnn_resnet50_fpn(
        weights=None, num_classes=num_classes
    )

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # that come from a trusted source.
    state_dict = torch.load(model_path, map_location="cpu")
    detector.load_state_dict(state_dict)

    detector.eval()
    return detector
# Transform input image
def preprocess_image(image):
    """Convert an image to a tensor with a leading batch dimension.

    Args:
        image: Input accepted by ``transforms.ToTensor`` (callers in this
            file pass a PIL image).

    Returns:
        The converted tensor with an extra dimension inserted at index 0.
    """
    to_tensor = transforms.ToTensor()
    tensor = to_tensor(image)
    return tensor.unsqueeze(0)
# Run inference on an image
def run_inference_on_image(image_path, model):
    """Run the detector on one image file and display the annotated result.

    Args:
        image_path: Path of the image to open.
        model: Callable detector; it is invoked on a batched tensor and the
            first element of its output is used as the prediction.
    """
    rgb_image = Image.open(image_path).convert("RGB")
    batch = preprocess_image(rgb_image)

    # Inference only: gradient tracking is switched off.
    with torch.no_grad():
        outputs = model(batch)

    draw_predictions(image_path, outputs[0])
# Draw predictions on an image
def draw_predictions(image_path, prediction):
    """Display an image file with its detections drawn on top.

    Args:
        image_path: Path of the image to load with OpenCV.
        prediction: Mapping with ``'boxes'``, ``'scores'`` and ``'labels'``
            entries, in the form consumed by ``draw_predictions_live``.

    The result window stays open until a key is pressed. If OpenCV cannot
    read the file, an error is printed and nothing is shown.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising;
        # without this guard the drawing/imshow calls fail with an opaque
        # OpenCV error.
        print(f"❌ Error: Could not read image '{image_path}'.")
        return

    # Reuse the annotation routine of the webcam path so both modes apply
    # the same confidence threshold and label formatting.
    draw_predictions_live(image, prediction)

    cv2.imshow("Detection Result", image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
# Mouse click event function to stop real-time detection
def stop_real_time_detection(event, x, y, flags, param):
    """Mouse callback that requests the webcam loop to stop.

    Only ``event`` is inspected; the remaining parameters exist to match the
    OpenCV mouse-callback signature and are ignored.
    """
    global stop_detection
    if event != cv2.EVENT_LBUTTONDOWN:
        return
    # A left-button press anywhere in the window raises the stop flag.
    stop_detection = True
# Real-time detection with webcam
def real_time_detection(model):
    """Run detection on webcam frames until the user stops it.

    Frames are read from capture device 0, annotated, and shown in a window.
    The loop ends when the window is clicked, when 'q' is pressed, or when a
    frame cannot be read.

    Args:
        model: Callable detector; it is invoked on a batched tensor and the
            first element of its output is used as the prediction.
    """
    global stop_detection
    stop_detection = False  # Reset the flag left over from a previous run.

    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("❌ Error: Could not open webcam.")
        return

    print("🎥 Starting real-time object detection. Click the window to close.")
    window_name = "Real-time Detection"

    # try/finally guarantees that the camera and the window are released even
    # when inference raises or the user interrupts with Ctrl+C.
    try:
        cv2.namedWindow(window_name)
        cv2.setMouseCallback(window_name, stop_real_time_detection)

        while True:
            ret, frame = cap.read()
            if not ret:
                print("❌ Error: Failed to capture frame.")
                break

            # The captured frame is converted from BGR to RGB before it is
            # handed to the preprocessing step.
            img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            img_tensor = preprocess_image(img)

            with torch.no_grad():
                prediction = model(img_tensor)[0]

            draw_predictions_live(frame, prediction)
            cv2.imshow(window_name, frame)

            # waitKey also processes window events so the window stays
            # responsive.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

            if stop_detection:
                print("🛑 Stopping real-time detection...")
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()
        # One more event-loop tick so the window is actually destroyed.
        cv2.waitKey(1)
# Draw predictions on live video feed
def draw_predictions_live(frame, prediction):
    """Draw boxes and captions for confident detections onto ``frame``.

    Args:
        frame: Image array accepted by the OpenCV drawing calls; it is
            modified in place.
        prediction: Mapping with ``'boxes'``, ``'scores'`` and ``'labels'``
            tensors holding one entry per detection.

    Only detections with a score above 0.5 are drawn.
    """
    colour = (0, 255, 0)

    # Convert each tensor to plain Python values once instead of indexing
    # the tensors element by element inside the loop.
    boxes = prediction['boxes'].tolist()
    scores = prediction['scores'].tolist()
    labels = prediction['labels'].tolist()

    for box, score, class_id in zip(boxes, scores, labels):
        if score > 0.5:  # Confidence threshold
            x1, y1, x2, y2 = map(int, box)
            class_name = CLASS_NAMES.get(class_id, f"Class {class_id}")
            cv2.rectangle(frame, (x1, y1), (x2, y2), colour, 2)
            # Keep the caption inside the image when the box touches the
            # top edge; y1 - 10 alone would place it partly off-image.
            text_y = max(y1 - 10, 15)
            cv2.putText(frame, f"{class_name}: {score:.2f}", (x1, text_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, colour, 2)
if __name__ == "__main__":
    # Interactive menu: load the model once, then serve requests until 'q'.
    model = load_model("models/fasterrcnn_model.pth", NUM_CLASSES)

    while True:
        print("\nOptions:")
        print("1 - Run object detection on an image")
        print("2 - Run real-time object detection (Click to exit)")
        print("q - Quit")
        choice = input("Enter your choice: ").strip().lower()

        if choice == "1":
            image_path = input("Enter the path of the test image: ").strip()
            try:
                run_inference_on_image(image_path, model)
            except OSError as exc:
                # Opening a missing or unreadable file raises an OSError
                # subclass; report it and return to the menu instead of
                # terminating the whole session on a mistyped path.
                print(f"❌ Error: Could not open image '{image_path}': {exc}")
        elif choice == "2":
            real_time_detection(model)
        elif choice == "q":
            print("👋 Exiting program.")
            break
        else:
            print("⚠️ Invalid choice! Please enter '1', '2', or 'q'.")