# PIWM / src/models/contours.py
# Commit c64c726 (musictimer): Initial Diamond CSGO AI deployment
# Vehicle Detection and State Estimation using Color-Based Contour Detection
#
# This script detects vehicles in a Bird's-Eye View (BEV) image by
# isolating their specific colors (green and blue) and then analyzing
# the shapes (contours) of the colored areas. The detected states are
# then exported to a CSV file.
#
# Required Libraries:
# - opencv-python (or opencv-python-headless, as installed below): For image
#   processing, color segmentation, and contour analysis.
# - numpy: For numerical operations.
#
# You can install them using pip:
# pip install opencv-python-headless numpy
import cv2
import numpy as np
import math
import csv # Import the csv module
def estimate_vehicle_states_by_color(image_path, min_area=50, color_ranges=None):
    """
    Detects vehicles in an image based on color, and estimates their position and heading.

    Each HSV color range is segmented separately. External contours of each
    mask are extracted, and a minimum-area rotated rectangle is fit to every
    contour whose area is at least ``min_area``. The rectangle's center gives
    the position; its longer side gives the heading.

    Args:
        image_path (str): The path to the input image.
        min_area (float): Contours whose ``cv2.contourArea`` is below this value
            are discarded as noise. Defaults to 50.
        color_ranges (sequence | None): Optional ``(class_name, lower_hsv, upper_hsv)``
            triples in OpenCV HSV units (H in [0, 179], S and V in [0, 255]).
            Ranges are processed in order. Defaults to a green "ego_vehicle"
            range followed by a blue "other_vehicle" range.

    Returns:
        tuple: ``(annotated_image, vehicle_states)``.
            ``annotated_image`` is a copy of the input with the detections drawn
            on it, or ``None`` if the image could not be read.
            ``vehicle_states`` is a list of dicts with keys ``class``,
            ``bounding_box_points``, ``position_x``, ``position_y``, ``speed``
            and ``heading``. ``heading`` is in degrees, measured in image
            coordinates (y axis pointing down) and folded into [-90, 90].
            The list is empty if the image could not be read.
    """
    if color_ranges is None:
        # These values are tuned for the specific green and blue in the
        # reference images. Format: (class, [H, S, V] lower, [H, S, V] upper).
        color_ranges = (
            ("ego_vehicle", (50, 100, 100), (70, 255, 255)),     # green ego vehicle
            ("other_vehicle", (85, 100, 100), (110, 255, 255)),  # blue vehicles
        )

    # 1. Load the image. cv2.imread returns None (rather than raising) for a
    #    missing/unreadable file; the except clause covers e.g. a bad path type.
    try:
        img = cv2.imread(image_path)
    except Exception as e:
        print(f"Error loading image: {e}")
        return None, []
    if img is None:
        print(f"Error: Could not read image from path: {image_path}")
        return None, []
    annotated_img = img.copy()  # draw on a copy so `img` stays untouched

    # 2. HSV separates hue from brightness, which makes color thresholds far
    #    simpler than in the default BGR layout.
    hsv_img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)

    vehicle_states = []
    # 3. Segment each color separately so every contour is classified by the
    #    mask it came from.
    for class_name, lower, upper in color_ranges:
        mask = cv2.inRange(hsv_img, np.asarray(lower), np.asarray(upper))
        # findContours returns (contours, hierarchy) in OpenCV 4 but
        # (image, contours, hierarchy) in OpenCV 3; [-2] works for both.
        contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        for contour in contours:
            # Filter out very small contours that are most likely noise.
            if cv2.contourArea(contour) < min_area:
                continue
            state, box_points = _estimate_vehicle_state(contour, class_name)
            vehicle_states.append(state)
            _draw_vehicle_annotation(annotated_img, state, box_points)
    return annotated_img, vehicle_states


def _estimate_vehicle_state(contour, class_name):
    """Fit a rotated rectangle to `contour`; return (state dict, integer box corners)."""
    rect = cv2.minAreaRect(contour)
    (pos_x, pos_y), _, _ = rect
    # Sub-pixel float corners are used for the heading; rounding them first
    # (as the drawn box requires) would quantize the angle by several degrees
    # on small boxes.
    corners = cv2.boxPoints(rect)
    box_points = np.intp(corners)
    # Speed needs tracking across frames; with a single frame it is a placeholder.
    speed = 0.0
    state = {
        "class": class_name,
        "bounding_box_points": box_points.tolist(),
        "position_x": pos_x,
        "position_y": pos_y,
        "speed": speed,
        "heading": _heading_from_box(corners),
    }
    return state, box_points


def _heading_from_box(corners):
    """Return the angle in degrees of a rectangle's longer side, folded into [-90, 90]."""
    corners = np.asarray(corners, dtype=float)
    side_a = corners[1] - corners[0]
    side_b = corners[2] - corners[1]
    # The vehicle's length is the longer side; on a tie the second side is used.
    length_vec = side_a if np.linalg.norm(side_a) > np.linalg.norm(side_b) else side_b
    heading = math.degrees(math.atan2(length_vec[1], length_vec[0]))
    # A rectangle's long axis has no front/back. As all vehicles in
    # highway-env move to the right, keep the heading in the right-hand plane.
    if heading > 90:
        heading -= 180
    elif heading < -90:
        heading += 180
    return heading


def _draw_vehicle_annotation(annotated_img, state, box_points):
    """Draw the box, center, heading line and heading label of one vehicle in place."""
    pos_x, pos_y = state["position_x"], state["position_y"]
    heading = state["heading"]
    center = (int(pos_x), int(pos_y))
    cv2.drawContours(annotated_img, [box_points], 0, (0, 255, 255), 2)  # yellow box
    cv2.circle(annotated_img, center, 5, (0, 0, 255), -1)  # red center dot
    line_length = 40  # length in pixels of the heading indicator
    angle_rad = math.radians(heading)
    end_point = (
        int(pos_x + line_length * math.cos(angle_rad)),
        int(pos_y + line_length * math.sin(angle_rad)),
    )
    cv2.line(annotated_img, center, end_point, (255, 0, 0), 2)  # blue heading line
    label = f"H: {heading:.1f}"
    # Plain ints: numpy integer scalars are not accepted as point coordinates
    # by every OpenCV build.
    label_origin = (int(box_points[1][0]), int(box_points[1][1]) - 10)
    cv2.putText(annotated_img, label, label_origin,
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 2)
def save_states_to_csv(states, csv_file_path):
    """
    Write vehicle states to a CSV file, one row per vehicle.

    Only the scalar columns class, position_x, position_y, heading and speed
    are written. Any other keys in a state dict (e.g. ``bounding_box_points``)
    are ignored. An existing file at the path is overwritten. I/O errors are
    reported on stdout instead of being raised.

    Args:
        states (list): Vehicle state dictionaries.
        csv_file_path (str): Destination path of the CSV file.
    """
    columns = ('class', 'position_x', 'position_y', 'heading', 'speed')
    try:
        # newline='' lets the csv module control line endings itself.
        with open(csv_file_path, 'w', newline='') as handle:
            row_writer = csv.DictWriter(handle, fieldnames=columns, extrasaction='ignore')
            row_writer.writeheader()
            for vehicle in states:
                row_writer.writerow(vehicle)
        print(f"\nVehicle states successfully saved to: {csv_file_path}")
    except IOError as err:
        print(f"Error writing to CSV file: {err}")
if __name__ == '__main__':
    import sys

    # Paths can be overridden on the command line:
    #   python contours.py [input_image] [output_image] [output_csv]
    # Alternative input used during development:
    #   '/home/zhexiao/Documents/highway_dataset/record_1/images/frame_000165.png'
    default_input_path = '/home/zhexiao/Documents/diamond/images/model_output.png'
    cli_args = sys.argv[1:]
    input_image_path = cli_args[0] if len(cli_args) > 0 else default_input_path
    output_image_path = cli_args[1] if len(cli_args) > 1 else 'frame_000001_contour_detected.png'
    output_csv_path = cli_args[2] if len(cli_args) > 2 else 'vehicle_states.csv'

    # Process the image to get states and the annotated image
    annotated_image, states = estimate_vehicle_states_by_color(input_image_path)

    if annotated_image is None:
        # The loader already printed the reason; exit non-zero instead of
        # misreporting a load failure as "no vehicles detected".
        sys.exit(1)

    if not states:
        print("No vehicles were detected in the image.")
    else:
        print("--- Detected Vehicle States (Contour Method) ---")
        # Sort states by x-position for consistent ordering
        states.sort(key=lambda v: v['position_x'])
        for i, state in enumerate(states):
            print(f"\nVehicle #{i+1}:")
            print(f" Class: {state['class']}")
            print(f" Position (x, y): ({state['position_x']:.2f}, {state['position_y']:.2f})")
            print(f" Heading (degrees): {state['heading']:.2f}")
            print(f" Speed: {state['speed']:.2f} (Note: Placeholder value)")
        # Save the annotated image
        cv2.imwrite(output_image_path, annotated_image)
        print(f"\nAnnotated image saved to: {output_image_path}")
        # Save the states to a CSV file
        save_states_to_csv(states, output_csv_path)

    # To display the image in a window (if you are running this on a desktop)
    # cv2.imshow('Vehicle Detection', annotated_image)
    # cv2.waitKey(0)
    # cv2.destroyAllWindows()