Spaces:
Sleeping
Sleeping
import gradio as gr | |
import numpy as np | |
import cv2 | |
import requests | |
import json | |
import base64 | |
from PIL import Image | |
import io | |
import os | |
from dotenv import load_dotenv | |
from collections import defaultdict | |
import time | |
# Load environment variables from a local .env file, if one is present.
load_dotenv()

# Base URL of the detection API. A trailing slash is stripped so endpoint
# paths built as f"{API_URL}/detect/..." never contain a double slash.
API_URL = os.getenv("API_URL", "http://122.155.170.240:81").rstrip("/")
print(f"Using API URL: {API_URL}")

# Default detection confidence threshold. A malformed environment value falls
# back to 0.25 instead of crashing the application at import time.
try:
    DEFAULT_CONFIDENCE = float(os.getenv("DEFAULT_CONFIDENCE_THRESHOLD", "0.25"))
except ValueError:
    DEFAULT_CONFIDENCE = 0.25
def calculate_iou(box1, box2):
    """Return the Intersection over Union (IoU) of two bounding boxes.

    Both boxes are indexed as (x1, y1, x2, y2). The result is 0 when the
    union area is not positive.
    """
    # Edges of the overlapping rectangle (may be inverted when disjoint).
    left = max(box1[0], box2[0])
    top = max(box1[1], box2[1])
    right = min(box1[2], box2[2])
    bottom = min(box1[3], box2[3])

    # Clamp each side at zero so disjoint boxes contribute no overlap.
    overlap_w = max(0, right - left)
    overlap_h = max(0, bottom - top)
    intersection = overlap_w * overlap_h

    def _area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    union = _area(box1) + _area(box2) - intersection
    if union > 0:
        return intersection / union
    return 0
def calculate_bbox_similarity(bbox1, bbox2):
    """Score how alike two boxes are by blending IoU with center distance.

    Both boxes are read as (x1, y1, x2, y2): that is the layout
    ``calculate_iou`` and the size normalization below index them with. The
    centers are therefore taken as corner midpoints directly instead of going
    through a helper that guesses the layout, so the IoU term and the distance
    term always describe the same rectangles.

    Returns:
        ``0.7 * IoU + 0.3 * max(0, 1 - 0.3 * d)``, where ``d`` is the center
        distance divided by the longer side of ``bbox1`` (at least 1).
        Returns 0.0 when either box does not have exactly four values or the
        values cannot be used in arithmetic.
    """
    try:
        if len(bbox1) != 4 or len(bbox2) != 4:
            return 0.0

        iou = calculate_iou(bbox1, bbox2)

        # Corner-format midpoints.
        center1 = ((bbox1[0] + bbox1[2]) / 2, (bbox1[1] + bbox1[3]) / 2)
        center2 = ((bbox2[0] + bbox2[2]) / 2, (bbox2[1] + bbox2[3]) / 2)
        distance = np.sqrt(
            (center1[0] - center2[0]) ** 2 + (center1[1] - center2[1]) ** 2
        )

        # Normalize by the longer side of the first box; the floor of 1 avoids
        # dividing by zero for degenerate boxes.
        bbox_size = max(bbox1[2] - bbox1[0], bbox1[3] - bbox1[1])
        normalized_distance = distance / max(bbox_size, 1)

        return iou * 0.7 + max(0, 1 - normalized_distance * 0.3) * 0.3
    except (TypeError, IndexError, KeyError, ValueError):
        # Malformed boxes are simply "not similar".
        return 0.0
def get_box_center(bbox):
    """Return the (x, y) center of a four-value bounding box, or None.

    The box layout is guessed from its values:

    * If the third value is smaller than the first, or the fourth smaller
      than the second, the box cannot be (x1, y1, x2, y2), so it is read as
      (x, y, w, h).
    * Otherwise it is read as (x1, y1, x2, y2).

    The original test was inverted, which produced the bottom-right corner
    for ordinary corner-format boxes and a point outside the box for
    (x, y, w, h) boxes.

    Returns None when ``bbox`` does not have exactly four values or its
    values cannot be used in arithmetic.
    """
    try:
        if len(bbox) != 4:
            return None
        if bbox[2] < bbox[0] or bbox[3] < bbox[1]:
            # Must be (x, y, w, h): a corner box never has x2 < x1 or y2 < y1.
            x = bbox[0] + bbox[2] / 2
            y = bbox[1] + bbox[3] / 2
        else:
            # (x1, y1, x2, y2): midpoint of opposite corners.
            x = (bbox[0] + bbox[2]) / 2
            y = (bbox[1] + bbox[3]) / 2
        return (x, y)
    except (TypeError, IndexError, KeyError):
        return None
def calculate_movement(prev_center, curr_center, min_movement=10):
    """Tell whether a center point moved more than ``min_movement``.

    Returns False when either center is missing or the distance cannot be
    computed; otherwise the result of comparing the Euclidean distance
    between the two centers against ``min_movement``.
    """
    if prev_center is None or curr_center is None:
        return False
    try:
        delta_x = curr_center[0] - prev_center[0]
        delta_y = curr_center[1] - prev_center[1]
        return np.sqrt(delta_x * delta_x + delta_y * delta_y) > min_movement
    except Exception:
        return False
class TrackedObject:
    """One object followed across frames, plus its red-zone bookkeeping."""

    def __init__(self, obj_id, obj_class, bbox):
        """Create a track and record ``bbox`` as its first detection."""
        self.id = obj_id
        self.class_name = obj_class

        # Position history (centers and the boxes they came from).
        self.trajectory = []
        self.bboxes = []

        # Counting / lifetime, expressed in frame numbers.
        self.counted = False
        self.last_seen = 0
        self.first_seen = 0

        # Red-zone state.
        self.frames_in_red_zone = 0  # consecutive frames inside the zone
        self.warning_triggered = False
        self.red_zone_entry_frame = None

        self.similarity_scores = []

        self.add_detection(bbox)

    def add_detection(self, bbox):
        """Append ``bbox`` and its center to the history.

        Detections whose center cannot be computed are ignored. Once the
        history exceeds 50 entries only the latest 25 are kept. Any error is
        swallowed so a bad detection never interrupts tracking.
        """
        try:
            center = get_box_center(bbox)
            if center is None:
                return
            self.trajectory.append(center)
            self.bboxes.append(bbox)
            if len(self.trajectory) > 50:
                # Trim both lists together so they stay aligned.
                self.trajectory = self.trajectory[-25:]
                self.bboxes = self.bboxes[-25:]
        except Exception:
            pass

    def has_movement(self, min_movement=10):
        """Return whether the last two centers differ by more than ``min_movement``."""
        try:
            if len(self.trajectory) < 2:
                return False
            previous, latest = self.trajectory[-2:]
            return calculate_movement(previous, latest, min_movement)
        except Exception:
            return False

    def update_red_zone_status(self, is_in_red_zone, frame_number):
        """Advance the red-zone counters for one frame.

        Returns True exactly once per stay: on the first frame where the
        object has been in the zone for more than 3 consecutive frames.
        Leaving the zone resets the counters and re-arms the warning.
        """
        if not is_in_red_zone:
            self.frames_in_red_zone = 0
            self.red_zone_entry_frame = None
            self.warning_triggered = False
            return False

        if self.red_zone_entry_frame is None:
            self.red_zone_entry_frame = frame_number
        self.frames_in_red_zone += 1

        should_warn = self.frames_in_red_zone > 3 and not self.warning_triggered
        if should_warn:
            self.warning_triggered = True
        return should_warn

    def get_similarity_with(self, other_bbox, similarity_threshold=0.5):
        """Return the similarity between the latest stored box and ``other_bbox``.

        ``similarity_threshold`` is accepted but not used. Returns 0.0 when
        no box has been stored yet.
        """
        if not self.bboxes:
            return 0.0
        return calculate_bbox_similarity(self.bboxes[-1], other_bbox)
def _as_corner_box(box):
    """Return ``box`` as [x1, y1, x2, y2], converting from (x, y, w, h) if needed."""
    if box[2] < box[0] or box[3] < box[1]:
        # A corner box never has x2 < x1 or y2 < y1, so this is (x, y, w, h).
        return [box[0], box[1], box[0] + box[2], box[1] + box[3]]
    return box


def is_similar_object(obj1, obj2, similarity_threshold=0.6):
    """Tell whether two detections look like the same object.

    Two detections are similar when they share the same ``'class'`` and the
    similarity of their ``'bbox'`` values (normalized to corner format)
    exceeds ``similarity_threshold``.

    The original layout test was inverted: it treated boxes whose third value
    was smaller than the first as already being in corner format, and
    re-converted genuine corner boxes as if they were (x, y, w, h).

    Returns False for mismatched classes, boxes that do not have exactly four
    values, or detections missing the expected keys.
    """
    try:
        if obj1['class'] != obj2['class']:
            return False

        box1, box2 = obj1['bbox'], obj2['bbox']
        if len(box1) != 4 or len(box2) != 4:
            return False

        similarity = calculate_bbox_similarity(
            _as_corner_box(box1), _as_corner_box(box2)
        )
        return similarity > similarity_threshold
    except (KeyError, TypeError, IndexError):
        return False
# Global state for protection area and previous detections
class State:
    """Mutable application state shared by the UI callbacks and the tracker."""

    def __init__(self):
        # --- Protection-area selection -------------------------------------
        self.protection_points = []          # points clicked by the user
        self.detected_segments = []
        self.segment_image = None
        self.selected_segments = []
        self.previous_detections = None
        self.cached_protection_area = None

        # --- Preview image ---------------------------------------------------
        self.current_image = None            # image used as drawing base
        self.original_dims = None            # original image dimensions
        self.display_dims = None             # display dimensions

        # --- Tracking --------------------------------------------------------
        # The tracking containers are created by the same routine that later
        # clears them, so both places always agree on their initial values.
        self.reset_tracking()
        self.time_window = 10                # frames considered for matching
        self.similarity_threshold = 0.6      # minimum similarity for a match

    def reset_tracking(self):
        """Clear every tracking counter and container.

        ``time_window`` and ``similarity_threshold`` are left untouched.
        """
        self.tracked_objects = {}                        # id -> tracked object
        self.next_obj_id = 0                             # next id to hand out
        self.object_count = defaultdict(int)             # count by class
        self.frame_count = 0                             # processed frames
        self.red_zone_passed_objects = defaultdict(int)  # passages by class
        self.red_zone_warnings = []                      # warning records


state = State()
def image_to_bytes(image):
    """Serialize a PIL image to PNG bytes for an API request.

    The image is saved as-is (no resizing); its dimensions are logged before
    and after encoding.
    """
    width, height = image.size
    print(f"Original image dimensions: {width}x{height}")

    png_buffer = io.BytesIO()
    image.save(png_buffer, format='PNG')

    print(f"Sending image with original dimensions: {width}x{height}")
    return png_buffer.getvalue()
def base64_to_image(base64_str):
    """Decode a base64-encoded image into an OpenCV color image."""
    encoded_pixels = np.frombuffer(base64.b64decode(base64_str), np.uint8)
    return cv2.imdecode(encoded_pixels, cv2.IMREAD_COLOR)
def opencv_to_pil(opencv_image):
    """Convert an OpenCV image to a PIL image, swapping BGR to RGB."""
    return Image.fromarray(cv2.cvtColor(opencv_image, cv2.COLOR_BGR2RGB))
def scale_point_to_original(x, y):
    """Map a display-space point to original-image coordinates.

    The point is returned unchanged when either the original or the display
    dimensions are not known yet.
    """
    original, display = state.original_dims, state.display_dims
    if original is None or display is None:
        return x, y

    orig_w, orig_h = original
    disp_w, disp_h = display

    # Per-axis ratio of original size to displayed size.
    ratio_x = orig_w / disp_w
    ratio_y = orig_h / disp_h
    return int(x * ratio_x), int(y * ratio_y)
def scale_points_to_display(points):
    """Map points from original-image coordinates to display coordinates.

    ``points`` is returned unchanged when either the original or the display
    dimensions are not known yet; otherwise a new list of ``[x, y]`` integer
    pairs is returned.
    """
    original, display = state.original_dims, state.display_dims
    if original is None or display is None:
        return points

    orig_w, orig_h = original
    disp_w, disp_h = display

    # Per-axis ratio of displayed size to original size.
    ratio_x = disp_w / orig_w
    ratio_y = disp_h / orig_h
    return [[int(pt[0] * ratio_x), int(pt[1] * ratio_y)] for pt in points]
def draw_protection_area(image):
    """Return a copy of ``image`` with the clicked protection points drawn.

    Draws, in order: the outline (closed only when there are exactly four
    points), each point with its 1-based number, and a semi-transparent fill
    once there are at least three points.
    """
    canvas = image.copy()
    pts = state.protection_points
    if not pts:
        return canvas

    polygon = np.array(pts, dtype=np.int32)

    # Outline between the points.
    if len(pts) > 1:
        cv2.polylines(canvas, [polygon], len(pts) == 4, (0, 255, 0), 2)

    # Numbered markers.
    for number, pt in enumerate(pts, start=1):
        cv2.circle(canvas, tuple(pt), 5, (0, 0, 255), -1)
        label_pos = (pt[0] + 10, pt[1] + 10)
        cv2.putText(canvas, str(number), label_pos,
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)

    # Translucent fill: blend a filled copy back onto the canvas.
    if len(pts) >= 3:
        filled = canvas.copy()
        cv2.fillPoly(filled, [polygon], (0, 255, 0))
        cv2.addWeighted(filled, 0.3, canvas, 0.7, 0, canvas)

    return canvas
def update_preview(video):
    """Load the first frame of ``video`` as the preview and reset selection state.

    Every path returns the same two values: the preview image (RGB array, or
    None when there is no video or no frame could be read) and a
    ``gr.update`` that clears and hides a choice component. The original
    returned three values when ``video`` was None, unlike its other paths.

    On success the protection-area selection held in ``state`` is cleared and
    the frame and its dimensions are stored there.
    """
    hidden_choices = gr.update(choices=[], value=[], visible=False)
    if video is None:
        return None, hidden_choices

    cap = cv2.VideoCapture(video)
    try:
        ret, frame = cap.read()
    finally:
        # Release the capture even if reading raises.
        cap.release()

    if not ret:
        return None, hidden_choices

    # Reset state
    state.protection_points = []
    state.detected_segments = []
    state.segment_image = None
    state.selected_segments = []
    state.previous_detections = None
    state.cached_protection_area = None

    # Keep the untouched frame; the display size equals the original size
    # because the frame is shown without resizing.
    state.current_image = frame.copy()
    state.original_dims = (frame.shape[1], frame.shape[0])  # (width, height)
    state.display_dims = state.original_dims

    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return frame_rgb, hidden_choices
def handle_image_click(evt: gr.SelectData, img):
    """Record a clicked point and redraw the protection area.

    A click made when four points already exist starts a new selection.
    Returns the redrawn image (RGB) and a status message; when no image has
    been loaded, ``img`` is returned unchanged with an error message.
    """
    # A fifth click starts over.
    if len(state.protection_points) >= 4:
        state.protection_points = []

    if state.current_image is None:
        return img, "Error: No image loaded"

    # Click coordinates are used directly, without any scaling.
    click_x, click_y = evt.index[0], evt.index[1]
    state.protection_points.append([click_x, click_y])

    pts = state.protection_points
    canvas = state.current_image.copy()

    # Numbered markers.
    for number, pt in enumerate(pts, start=1):
        cv2.circle(canvas, (pt[0], pt[1]), 5, (0, 0, 255), -1)
        cv2.putText(canvas, str(number),
                    (pt[0] + 10, pt[1] + 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)

    if len(pts) > 1:
        outline = np.array(pts, dtype=np.int32)
        # The outline is closed only once all four points are placed.
        cv2.polylines(canvas, [outline], len(pts) == 4, (0, 255, 0), 2)

        # Translucent fill as soon as the points form a polygon.
        if len(pts) >= 3:
            filled = canvas.copy()
            cv2.fillPoly(filled, [outline], (0, 255, 0))
            cv2.addWeighted(filled, 0.3, canvas, 0.7, 0, canvas)

    canvas_rgb = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
    status = f"Selected {len(state.protection_points)} points\nCoordinates: {state.protection_points}"
    return canvas_rgb, status
def reset_points():
    """Clear the clicked points and return the clean image with a status.

    The image is the stored frame converted to RGB, or None when no frame
    has been stored.
    """
    state.protection_points = []
    if state.current_image is None:
        return None, "Points reset"
    clean_rgb = cv2.cvtColor(state.current_image.copy(), cv2.COLOR_BGR2RGB)
    return clean_rgb, "Points reset"
def detect_rail_segments(image):
    """Ask the API for rail segments in ``image``.

    Returns ``(segments, annotated_image)`` on success. Returns ``([], None)``
    when the API answers with a non-200 status, the response has no
    ``"segments"`` entry, or any error occurs (errors are printed).
    """
    try:
        width, height = image.size
        print(f"Detecting rail segments on image with dimensions: {width}x{height}")

        payload = {"file": image_to_bytes(image)}
        response = requests.post(
            f"{API_URL}/detect/rail-segment",
            files=payload,
            timeout=60
        )

        if response.status_code != 200:
            print(f"API error: {response.status_code} - Image size was {width}x{height}")
            return [], None

        result = response.json()
        if "segments" not in result:
            return [], None
        return result["segments"], base64_to_image(result["image_base64"])
    except Exception as e:
        print(f"Error in detect_rail_segments: {str(e)}")
        return [], None
def extract_protection_area(first_frame):
    """Detect rail segments on a frame and cache the first one as protection area.

    Side effects on ``state``: stores the detected segments, the segment
    image (resized to the frame size when it differs), selects every segment
    by default, and caches the first segment's mask as the protection area
    when it has at least three points.

    Returns:
        ``(True, message, rgb_image)`` when the first segment yields a valid
        polygon and a segment image is available; otherwise
        ``(False, message, None)``.
    """
    try:
        frame_h, frame_w = first_frame.shape[:2]
        print(f"Extracting protection area from frame with dimensions: {frame_w}x{frame_h}")

        # The API helper takes a PIL image; the frame is not resized.
        pil_frame = Image.fromarray(cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB))
        pil_w, pil_h = pil_frame.size
        print(f"PIL Image dimensions before API call: {pil_w}x{pil_h}")

        segments, segment_img = detect_rail_segments(pil_frame)
        if not (segments and len(segments) > 0):
            return False, "No valid rail segments detected", None

        # Bring the annotated image to the frame size if the API changed it.
        if segment_img is not None:
            seg_h, seg_w = segment_img.shape[:2]
            print(f"Received segment image dimensions: {seg_w}x{seg_h}")
            if (seg_w, seg_h) != (frame_w, frame_h):
                print(f"Resizing segment image from {seg_w}x{seg_h} to {frame_w}x{frame_h}")
                segment_img = cv2.resize(segment_img, (frame_w, frame_h), interpolation=cv2.INTER_LANCZOS4)

        state.detected_segments = segments
        state.segment_image = segment_img

        # One label per segment; mask extents are only logged.
        labels = []
        for number, seg in enumerate(segments, start=1):
            outline = seg.get('mask', [])
            if outline:
                xs = [p[0] for p in outline]
                ys = [p[1] for p in outline]
                extent_w = max(xs) - min(xs)
                extent_h = max(ys) - min(ys)
                print(f"Segment {number} mask dimensions: {extent_w}x{extent_h}")
            labels.append(f"Segment {number} (Confidence: {seg['confidence']:.2f})")
        state.selected_segments = labels  # every segment selected by default

        # The first segment's mask becomes the protection polygon.
        first = segments[0]
        if 'mask' in first and first['mask']:
            polygon = [[int(float(x)), int(float(y))] for x, y in first['mask']]
            if len(polygon) >= 3:  # a polygon needs at least three points
                state.cached_protection_area = polygon
                if segment_img is not None:
                    rgb_image = cv2.cvtColor(segment_img, cv2.COLOR_BGR2RGB)
                    return True, "Protection area extracted successfully", rgb_image

        return False, "Invalid mask points in segment", None
    except Exception as e:
        print(f"Error in extract_protection_area: {str(e)}")
        return False, f"Error extracting protection area: {str(e)}", None
def get_segment_index(choice_text):
    """Return the zero-based segment index encoded in a choice label.

    Labels look like ``"Segment X (Confidence: Y)"``; ``X`` is 1-based.
    Returns -1 when the label is not a string in that shape. Only parsing
    errors are caught, so interrupts and exit requests are no longer
    swallowed as they were by the original bare ``except``.
    """
    try:
        return int(choice_text.split()[1]) - 1
    except (AttributeError, IndexError, TypeError, ValueError):
        return -1
def update_object_tracking(objects_in_area):
    """Update the tracked objects in ``state`` with this frame's detections.

    For each detection (a dict with ``'bbox'`` and ``'class'``):

    * it is matched to the most similar recently-seen track of the same
      class whose similarity exceeds ``state.similarity_threshold``;
    * a track can absorb at most one detection per frame. Previously several
      detections could match the same track, which advanced its red-zone
      counter more than once per frame and could register a jump between two
      different objects as movement;
    * a matched track updates its red-zone status, may emit a warning, and is
      counted once as having passed the red zone when it moved while inside
      it and has at least three trajectory points;
    * an unmatched detection starts a new track.

    Tracks not seen this frame have their red-zone status reset, and tracks
    unseen for more than twice the time window are dropped. Errors are
    printed rather than raised.
    """
    try:
        current_tracked = set()   # ids matched or created in this frame
        current_warnings = []     # warning messages raised in this frame

        for obj in objects_in_area:
            try:
                if 'bbox' not in obj or 'class' not in obj:
                    continue

                bbox = obj['bbox']
                obj_class = obj['class']
                is_in_red_zone = obj.get('in_protection_area', False)

                # Find the best candidate track for this detection.
                best_match_id = None
                best_similarity = 0.0
                for obj_id, tracked in state.tracked_objects.items():
                    if obj_id in current_tracked:
                        # Already paired with another detection this frame.
                        continue
                    if tracked.class_name != obj_class:
                        continue
                    if state.frame_count - tracked.last_seen > state.time_window:
                        # Too old to be considered the same object.
                        continue
                    similarity = tracked.get_similarity_with(bbox)
                    if similarity > state.similarity_threshold and similarity > best_similarity:
                        best_similarity = similarity
                        best_match_id = obj_id

                if best_match_id is not None:
                    tracked = state.tracked_objects[best_match_id]
                    tracked.add_detection(bbox)
                    tracked.last_seen = state.frame_count
                    current_tracked.add(best_match_id)

                    # Red-zone status and warnings.
                    warning_triggered = tracked.update_red_zone_status(is_in_red_zone, state.frame_count)
                    if warning_triggered:
                        warning_msg = f"β οΈ WARNING: {tracked.class_name} (ID: {tracked.id}) has been in red zone for {tracked.frames_in_red_zone} frames!"
                        current_warnings.append(warning_msg)
                        state.red_zone_warnings.append({
                            'frame': state.frame_count,
                            'object_id': tracked.id,
                            'class': tracked.class_name,
                            'frames_in_zone': tracked.frames_in_red_zone,
                            'message': warning_msg
                        })

                    # Count only objects that actually move through the zone
                    # and have been tracked for at least a few frames.
                    if not tracked.counted and tracked.has_movement() and is_in_red_zone:
                        if len(tracked.trajectory) >= 3:
                            tracked.counted = True
                            state.red_zone_passed_objects[obj_class] += 1
                else:
                    # No suitable track: start a new one.
                    new_obj = TrackedObject(state.next_obj_id, obj_class, bbox)
                    new_obj.last_seen = state.frame_count
                    new_obj.first_seen = state.frame_count
                    state.tracked_objects[state.next_obj_id] = new_obj
                    current_tracked.add(state.next_obj_id)
                    state.next_obj_id += 1
                    new_obj.update_red_zone_status(is_in_red_zone, state.frame_count)
            except Exception as e:
                # One bad detection must not stop the rest of the frame, but
                # it should not disappear without a trace either.
                print(f"Skipping detection in update_object_tracking: {str(e)}")
                continue

        # Tracks without a detection this frame are treated as outside the zone.
        for obj_id, tracked in state.tracked_objects.items():
            if obj_id not in current_tracked:
                tracked.update_red_zone_status(False, state.frame_count)

        # Drop tracks unseen for more than twice the time window.
        if state.frame_count > state.time_window:
            stale_ids = [
                obj_id
                for obj_id, tracked in state.tracked_objects.items()
                if state.frame_count - tracked.last_seen > state.time_window * 2
            ]
            for obj_id in stale_ids:
                del state.tracked_objects[obj_id]

        if current_warnings:
            print(f"Frame {state.frame_count} Warnings: {current_warnings}")
    except Exception as e:
        print(f"Error in update_object_tracking: {str(e)}")
def get_red_zone_summary():
    """Build a text summary of red-zone activity from ``state``.

    Sections, each included only when it has content: passages per class,
    tracks currently in the zone, and up to three warnings from the last
    five frames. Returns a fixed message when every section is empty.
    """
    lines = []

    # Passages counted so far.
    passed = state.red_zone_passed_objects
    if passed:
        lines.append("π΄ RED ZONE PASSAGE SUMMARY:")
        lines.append(f"Total objects passed: {sum(passed.values())}")
        lines.extend(
            f" β’ {obj_class}: {count}" for obj_class, count in sorted(passed.items())
        )

    # Tracks currently inside the zone.
    in_zone = [
        f"{tracked.class_name} (ID: {tracked.id}, {tracked.frames_in_red_zone} frames)"
        for tracked in state.tracked_objects.values()
        if tracked.frames_in_red_zone > 0
    ]
    if in_zone:
        lines.append("\nπ¨ CURRENTLY IN RED ZONE:")
        lines.extend(f" β’ {obj_info}" for obj_info in in_zone)

    # Warnings raised within the last five frames (latest three shown).
    recent = [
        entry for entry in state.red_zone_warnings
        if state.frame_count - entry['frame'] <= 5
    ]
    if recent:
        lines.append("\nβ οΈ RECENT WARNINGS:")
        lines.extend(
            f" β’ Frame {warning['frame']}: {warning['message']}" for warning in recent[-3:]
        )

    if not lines:
        return "No objects detected in red zone yet."
    return "\n".join(lines)
def process_frame(frame, confidence):
    """Send one frame to the detection API and update tracking from the reply.

    The protection area is the mask of the first selected segment that has
    one, or, when no segments are selected/detected, the clicked points if
    there are at least three.

    Args:
        frame: OpenCV image to analyze.
        confidence: Confidence threshold forwarded to the API as a string.

    Returns:
        ``(rgb_image, status_text)`` on success, or ``(None, message)``
        describing why the frame could not be processed. No exception is
        raised to the caller.

    Changes from the original: the bare ``except`` around parsing of an error
    body is now ``except Exception``, and a detection without a
    ``'confidence'`` value is reported with 0.00 instead of aborting the frame
    after tracking state had already been updated.
    """
    try:
        # Resolve the protection polygon.
        protection_area = []
        if state.selected_segments and state.detected_segments:
            for choice in state.selected_segments:
                idx = get_segment_index(choice)
                if 0 <= idx < len(state.detected_segments):
                    segment = state.detected_segments[idx]
                    if 'mask' in segment and segment['mask']:
                        protection_area = segment['mask']
                        break
        elif len(state.protection_points) >= 3:
            protection_area = state.protection_points

        if not protection_area:
            return None, "Protection area not set. Please extract protection area first."

        # Ensure frame is valid
        if frame is None or frame.size == 0:
            return None, "Invalid frame"

        success, buffer = cv2.imencode('.png', frame)
        if not success:
            return None, "Failed to encode frame"

        files = {
            "file": ("frame.png", buffer.tobytes(), "image/png")
        }
        data = {
            "protection_area": json.dumps(protection_area),
            "confidence_threshold": str(confidence)
        }
        if state.previous_detections:
            data["previous_detections"] = json.dumps(state.previous_detections)

        try:
            response = requests.post(
                f"{API_URL}/detect/objects-and-redlight",
                files=files,
                data=data,
                timeout=60
            )

            if response.status_code != 200:
                error_detail = f"API Error: {response.status_code}"
                try:
                    error_json = response.json()
                    if 'detail' in error_json:
                        error_detail += f" - {error_json['detail']}"
                except Exception:
                    # Body was not usable JSON; fall back to the raw text.
                    error_detail += f" - {response.text}"
                return None, error_detail

            result = response.json()
            if not result.get("success"):
                return None, f"API Error: {result.get('detail', 'Unknown error')}"

            result_data = result.get("result", {})
            if not result_data:
                return None, "No result data received"

            red_light_info = result_data.get("red_light", {})
            red_light_detected = red_light_info.get("detected", False)
            red_light_prob = red_light_info.get("probability", 0)

            img_base64 = result_data.get("image_base64")
            if not img_base64:
                return None, "No image data received from API"

            try:
                # Drop a "data:...;base64," style prefix if present.
                if ',' in img_base64:
                    img_base64 = img_base64.split(',')[1]
                img_data = base64.b64decode(img_base64)
                nparr = np.frombuffer(img_data, np.uint8)
                processed_img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
                if processed_img is None or processed_img.size == 0:
                    return None, "Failed to decode image from API response"

                # Only well-formed detections inside the protection area
                # are tracked.
                objects_in_area = [obj for obj in result_data.get("objects", [])
                                   if obj.get("in_protection_area", False) and
                                   'bbox' in obj and 'class' in obj]

                # Update object tracking
                state.frame_count += 1
                update_object_tracking(objects_in_area)

                # Cache detections for next frame
                state.previous_detections = objects_in_area

                processed_img_rgb = cv2.cvtColor(processed_img, cv2.COLOR_BGR2RGB)

                status = []
                status.append(f"Red Light: {'YES' if red_light_detected else 'NO'} ({red_light_prob:.2f})")

                red_zone_summary = get_red_zone_summary()
                status.append(f"\n{red_zone_summary}")

                if objects_in_area:
                    status.append("\nπ CURRENT FRAME DETECTIONS:")
                    for obj in objects_in_area:
                        status.append(f" β’ {obj['class']} (confidence: {obj.get('confidence', 0.0):.2f})")

                # Tracks seen within the last three processed frames.
                active_objects = len([obj for obj in state.tracked_objects.values()
                                      if state.frame_count - obj.last_seen <= 3])
                status.append(f"\nπ TRACKING STATS:")
                status.append(f" β’ Active tracked objects: {active_objects}")
                status.append(f" β’ Frame: {state.frame_count}")
                status.append(f" β’ Time window: {state.time_window} frames")
                status.append(f" β’ Similarity threshold: {state.similarity_threshold:.2f}")

                return processed_img_rgb, "\n".join(status)
            except Exception as e:
                return None, f"Error processing detection results: {str(e)}"
        except requests.exceptions.Timeout:
            return None, "API request timed out"
        except requests.exceptions.ConnectionError:
            return None, "Could not connect to API server"
        except Exception as e:
            return None, f"API request failed: {str(e)}"
    except Exception as e:
        return None, f"Error processing frame: {str(e)}"
def process_video(video, confidence=DEFAULT_CONFIDENCE, target_fps=1):
    """Yield processed frames of ``video`` one by one, then a final summary.

    Frames are sampled every ``int(fps / target_fps)`` frames (at least every
    frame). When ``target_fps`` or the reported FPS is missing, non-positive
    or NaN, every frame is processed instead of raising, as the original did
    for ``target_fps=0``.

    Yields:
        ``(rgb_image, status)`` for each processed frame, ``(None, status)``
        when a frame could not be processed, and finally ``(None, summary)``.
        Problems opening or reading the video are yielded as
        ``(None, error message)``.
    """
    cap = cv2.VideoCapture(video)
    if not cap.isOpened():
        cap.release()
        yield None, "Error: Could not open video file"
        return

    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)

        # The comparisons are False for NaN as well as for zero/negatives.
        if target_fps and target_fps > 0 and fps and fps > 0:
            frame_interval = max(1, int(fps / target_fps))
        else:
            frame_interval = 1

        frame_number = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            frame_number += 1
            if frame_number % frame_interval != 0:
                continue

            # process_frame already returns the image in RGB.
            processed_frame, result = process_frame(frame, confidence)
            if processed_frame is not None:
                current_status = f"Processing frame {frame_number}/{total_frames}\n{result}"
                yield processed_frame, current_status
            else:
                current_status = f"Frame {frame_number}: {result}"
                yield None, current_status

        # Release before the last yield so the file is not held open while
        # the consumer handles the summary; releasing again in ``finally``
        # is harmless.
        cap.release()

        final_summary = generate_final_summary()
        yield None, final_summary
    except Exception as e:
        yield None, f"Error processing video: {str(e)}"
    finally:
        cap.release()
def generate_final_summary():
    """Build the end-of-video report from ``state``.

    Sections: processing statistics, red-zone passages per class, warnings
    (per class plus the last five), and the tracks still held in ``state``
    grouped by class.
    """
    report = [
        "π¬ VIDEO PROCESSING COMPLETE",
        "=" * 50,
        # Processing statistics
        "π PROCESSING STATISTICS:",
        f" β’ Total frames processed: {state.frame_count}",
        f" β’ Time window used: {state.time_window} frames",
        f" β’ Similarity threshold: {state.similarity_threshold:.2f}",
    ]

    # Red zone passages: the heading is the same whether or not any occurred.
    report.append("\nπ΄ RED ZONE PASSAGE SUMMARY:")
    passed = state.red_zone_passed_objects
    if passed:
        report.append(f" β’ Total objects passed through red zone: {sum(passed.values())}")
        for obj_class, count in sorted(passed.items()):
            report.append(f" - {obj_class}: {count}")
    else:
        report.append(" β’ No objects detected passing through red zone")

    # Warnings.
    report.append("\nβ οΈ WARNING SUMMARY:")
    warnings_log = state.red_zone_warnings
    if warnings_log:
        report.append(f" β’ Total warnings generated: {len(warnings_log)}")

        per_class = defaultdict(int)
        for entry in warnings_log:
            per_class[entry['class']] += 1
        for obj_class, count in sorted(per_class.items()):
            report.append(f" - {obj_class}: {count} warnings")

        report.append("\n π Recent warnings:")
        for warning in warnings_log[-5:]:  # last five
            report.append(f" - Frame {warning['frame']}: {warning['class']} (ID: {warning['object_id']}) - {warning['frames_in_zone']} frames in zone")
    else:
        report.append(" β’ No warnings generated (no objects stayed in red zone > 3 frames)")

    # Tracks still present in state.
    total_tracked = len(state.tracked_objects)
    if total_tracked > 0:
        report.append("\nπ OBJECT TRACKING SUMMARY:")
        report.append(f" β’ Total unique objects tracked: {total_tracked}")

        tracks_per_class = defaultdict(int)
        for tracked in state.tracked_objects.values():
            tracks_per_class[tracked.class_name] += 1
        for obj_class, count in sorted(tracks_per_class.items()):
            report.append(f" - {obj_class}: {count}")

    report.append("\nβ Processing completed successfully!")
    return "\n".join(report)
def extract_area_from_video(video):
    """Detect rail segments on the first frame of ``video``.

    Returns three values: the annotated image (RGB) or None, a status
    message, and a ``gr.update`` for a choice component listing the detected
    segments (all selected and visible on success, cleared and hidden
    otherwise).

    ``extract_protection_area`` already returns its image converted to RGB,
    so it is passed through unchanged. The original converted it BGR→RGB a
    second time, which swapped the red and blue channels back.
    """
    hidden_choices = gr.update(choices=[], value=[], visible=False)
    if video is None:
        return None, "Please upload a video", hidden_choices

    cap = cv2.VideoCapture(video)
    try:
        ret, frame = cap.read()
    finally:
        # Release the capture even if reading raises.
        cap.release()

    if not ret:
        return None, "Could not read video frame", hidden_choices

    success, message, segment_img_rgb = extract_protection_area(frame)
    if not success or segment_img_rgb is None:
        return None, message, hidden_choices

    # Labels must match the "Segment X (Confidence: Y)" format that
    # get_segment_index parses.
    segment_choices = [f"Segment {i+1} (Confidence: {segment['confidence']:.2f})"
                       for i, segment in enumerate(state.detected_segments)]
    return segment_img_rgb, message, gr.update(choices=segment_choices, value=segment_choices, visible=True)
def update_selected_segments(selected):
    """Store the chosen segment labels in ``state`` (None becomes an empty list).

    Returns an empty ``gr.update()``.
    """
    state.selected_segments = [] if selected is None else selected
    return gr.update()
def process_video_wrapper(video, confidence=DEFAULT_CONFIDENCE, target_fps=1, time_window=10, similarity_threshold=0.6):
    """Validate preconditions, then stream results from ``process_video``.

    This is a generator: every item is a ``(frame, status_text)`` pair.

    Args:
        video: Uploaded video, or ``None`` if nothing was uploaded.
        confidence: Detection confidence forwarded to ``process_video``.
        target_fps: Processing frame rate forwarded to ``process_video``.
        time_window: Stored on ``state.time_window`` before processing.
        similarity_threshold: Stored on ``state.similarity_threshold``
            before processing.

    Yields:
        ``(None, message)`` for validation failures, the start banner and
        errors; otherwise each ``(frame, status)`` produced by
        ``process_video``.
    """
    if video is None:
        yield None, "Please upload a video"
        return

    # Reset tracking state and update parameters
    state.reset_tracking()
    state.time_window = time_window
    state.similarity_threshold = similarity_threshold

    # Resolve a protection area purely as a precondition check: the value is
    # not passed to process_video, which is called with video/confidence/fps.
    area = []
    if state.selected_segments and state.detected_segments:
        segments = state.detected_segments
        for label in state.selected_segments:
            position = get_segment_index(label)
            if not 0 <= position < len(segments):
                continue
            candidate = segments[position]
            if 'mask' in candidate and candidate['mask']:
                # First selected segment with a non-empty mask wins.
                area = candidate['mask']
                break
    elif len(state.protection_points) >= 3:
        # Manually clicked points are only considered when no segments
        # are selected/detected.
        area = state.protection_points

    if not area:
        yield None, "Please extract protection area first"
        return

    try:
        yield None, f"π Starting video processing...\nβοΈ Time window: {time_window} frames\nβοΈ Similarity threshold: {similarity_threshold:.2f}"
        for update in process_video(video, confidence, target_fps):
            frame, status = update
            yield frame, status
    except Exception as e:
        # Surface any processing failure in the status textbox.
        yield None, f"Error processing video: {e!s}"
# Build the Gradio interface: layout first, then event wiring.
# NOTE(review): the container nesting below (which components sit inside which
# Row/Column) was reconstructed from a copy of this file that had lost its
# indentation — confirm the layout against the running app.
with gr.Blocks(title="Enhanced Rail Traffic Monitor") as demo:
    # Static help text shown at the top of the page.
    gr.Markdown("""
    # Enhanced Rail Traffic Monitoring System
    ## Features:
    - **Smart Object Tracking**: Uses similarity method to track objects across frames
    - **Red Zone Monitoring**: Counts objects passing through the red zone
    - **Warning System**: Alerts when objects stay in red zone for more than 3 frames
    - **Configurable Parameters**: Adjust time window and similarity threshold
    ## Setup Instructions:
    **Method 1 (Manual Protection Area):**
    1. Click 4 points on the image to define protection area
    2. Click "Reset Points" to start over
    **Method 2 (Automatic Detection):**
    1. Click "Extract Protection Area" to automatically detect rail segments
    **Processing:**
    3. Adjust detection confidence, processing frame rate, time window, and similarity threshold
    4. Click "Process Video" to analyze
    The system will show real-time results including:
    - Objects currently in red zone
    - Total count of objects that passed through
    - Warnings for objects staying too long in red zone
    - Tracking statistics
    """)
    with gr.Row():
        # Input side: video upload plus the four tuning sliders.
        with gr.Column():
            video_input = gr.Video(
                label="Input Video"
            )
            with gr.Row():
                # Forwarded to process_video_wrapper as `confidence`.
                confidence = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=DEFAULT_CONFIDENCE,
                    label="Detection Confidence Threshold",
                    info="Minimum confidence for object detection"
                )
                # Forwarded to process_video_wrapper as `target_fps`.
                fps_slider = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=1,
                    step=1,
                    label="Processing Frame Rate (FPS)",
                    info="Frames per second to process"
                )
            with gr.Row():
                # Forwarded to process_video_wrapper as `time_window`.
                time_window_slider = gr.Slider(
                    minimum=5,
                    maximum=50,
                    value=10,
                    step=1,
                    label="Time Window (frames)",
                    info="Number of frames to consider for object similarity"
                )
                # Forwarded to process_video_wrapper as `similarity_threshold`.
                similarity_threshold_slider = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.6,
                    step=0.05,
                    label="Similarity Threshold",
                    info="Threshold for considering objects as the same (higher = stricter)"
                )
        # Protection-area side: clickable preview, segment picker, area buttons.
        with gr.Column():
            preview_image = gr.Image(
                label="Click to Select Protection Area (Original Size)",
                interactive=True,
                show_label=True
            )
            # Segment selection dropdown; starts hidden and empty, and is
            # populated/shown by extract_area_from_video on success.
            segment_dropdown = gr.Dropdown(
                label="Selected Segments",
                choices=[],
                multiselect=True,
                interactive=True,
                visible=False,
                value=[]
            )
            with gr.Row():
                reset_btn = gr.Button("Reset Points")
                extract_btn = gr.Button("Extract Protection Area")
    process_btn = gr.Button("π Process Video")
    # Output area: latest processed frame next to the textual status/summary.
    with gr.Row():
        video_output = gr.Image(
            label="Live Processing Output",
            streaming=True,
            interactive=False,
            show_label=True,
            container=True,
            show_download_button=True
        )
        text_output = gr.Textbox(
            label="Detection Results & Red Zone Summary",
            lines=15,
            max_lines=20,
            show_copy_button=True
        )
    # Handle video upload to populate preview (update_preview is defined
    # elsewhere in this file).
    video_input.change(
        fn=update_preview,
        inputs=[video_input],
        outputs=[preview_image, segment_dropdown]
    )
    # Automatic detection: updates the preview, the status text and the
    # segment dropdown in one call.
    extract_btn.click(
        fn=extract_area_from_video,
        inputs=[video_input],
        outputs=[preview_image, text_output, segment_dropdown]
    )
    # Persist the dropdown selection; the handler returns a bare gr.update()
    # for the dropdown itself.
    segment_dropdown.change(
        fn=update_selected_segments,
        inputs=[segment_dropdown],
        outputs=[segment_dropdown]
    )
    # process_video_wrapper is a generator, so each yielded (frame, status)
    # pair is written to these two outputs in turn.
    process_btn.click(
        fn=process_video_wrapper,
        inputs=[video_input, confidence, fps_slider, time_window_slider, similarity_threshold_slider],
        outputs=[video_output, text_output]
    )
    # Click on the preview image (handle_image_click is defined elsewhere in
    # this file).
    preview_image.select(
        fn=handle_image_click,
        inputs=[preview_image],
        outputs=[preview_image, text_output]
    )
    # Reset button handler (reset_points is defined elsewhere in this file).
    reset_btn.click(
        fn=reset_points,
        inputs=[],
        outputs=[preview_image, text_output]
    )

# Enable request queuing, then start the app when run as a script.
if __name__ == "__main__":
    demo.queue().launch()