import gradio as gr import numpy as np import cv2 import requests import json import base64 from PIL import Image import io import os from dotenv import load_dotenv from collections import defaultdict import time # Load environment variables load_dotenv() # Define API endpoint from environment variable API_URL = os.getenv("API_URL", "http://122.155.170.240:81") print(f"Using API URL: {API_URL}") DEFAULT_CONFIDENCE = float(os.getenv("DEFAULT_CONFIDENCE_THRESHOLD", "0.25")) def calculate_iou(box1, box2): """Calculate Intersection over Union (IoU) between two bounding boxes""" x1 = max(box1[0], box2[0]) y1 = max(box1[1], box2[1]) x2 = min(box1[2], box2[2]) y2 = min(box1[3], box2[3]) intersection = max(0, x2 - x1) * max(0, y2 - y1) area1 = (box1[2] - box1[0]) * (box1[3] - box1[1]) area2 = (box2[2] - box2[0]) * (box2[3] - box2[1]) union = area1 + area2 - intersection return intersection / union if union > 0 else 0 def calculate_bbox_similarity(bbox1, bbox2): """Calculate similarity between two bounding boxes using IoU and center distance""" try: # Calculate IoU iou = calculate_iou(bbox1, bbox2) # Calculate center distance center1 = get_box_center(bbox1) center2 = get_box_center(bbox2) if center1 is None or center2 is None: return 0.0 distance = np.sqrt((center1[0] - center2[0])**2 + (center1[1] - center2[1])**2) # Normalize distance based on bbox size bbox_size = max(bbox1[2] - bbox1[0], bbox1[3] - bbox1[1]) normalized_distance = distance / max(bbox_size, 1) # Combine IoU and distance for final similarity score similarity = iou * 0.7 + max(0, 1 - normalized_distance * 0.3) * 0.3 return similarity except Exception as e: return 0.0 def get_box_center(bbox): """Calculate center point of bounding box""" try: # Handle different bbox formats (x,y,w,h) or (x1,y1,x2,y2) if len(bbox) == 4: if bbox[2] < bbox[0] or bbox[3] < bbox[1]: # If it's x1,y1,x2,y2 format x = (bbox[0] + bbox[2]) / 2 y = (bbox[1] + bbox[3]) / 2 else: # If it's x,y,w,h format x = bbox[0] + bbox[2]/2 y = bbox[1] + bbox[3]/2 else: return None return (x, y) except Exception as e: return None def calculate_movement(prev_center, curr_center, min_movement=10): """Calculate if there's significant movement between frames""" try: if prev_center is None or curr_center is None: return False dx = curr_center[0] - prev_center[0] dy = curr_center[1] - prev_center[1] distance = np.sqrt(dx*dx + dy*dy) return distance > min_movement except Exception as e: return False class TrackedObject: def __init__(self, obj_id, obj_class, bbox): self.id = obj_id self.class_name = obj_class self.trajectory = [] # List of center points self.bboxes = [] # List of bounding boxes self.counted = False self.last_seen = 0 # Frame number when last seen self.first_seen = 0 # Frame number when first seen self.frames_in_red_zone = 0 # Number of consecutive frames in red zone self.warning_triggered = False # Whether warning has been triggered self.red_zone_entry_frame = None # Frame when object entered red zone self.similarity_scores = [] # Track similarity scores over time self.add_detection(bbox) def add_detection(self, bbox): try: center = get_box_center(bbox) if center is not None: self.trajectory.append(center) self.bboxes.append(bbox) # Keep only recent history to prevent memory issues if len(self.trajectory) > 50: self.trajectory = self.trajectory[-25:] self.bboxes = self.bboxes[-25:] except Exception as e: pass def has_movement(self, min_movement=10): try: if len(self.trajectory) < 2: return False return calculate_movement(self.trajectory[-2], self.trajectory[-1], min_movement) except Exception as e: return False def update_red_zone_status(self, is_in_red_zone, frame_number): """Update red zone status and handle warnings""" if is_in_red_zone: if self.red_zone_entry_frame is None: self.red_zone_entry_frame = frame_number self.frames_in_red_zone += 1 # Check if warning should be triggered if self.frames_in_red_zone > 3 and not self.warning_triggered: self.warning_triggered = True return True # Return True to indicate warning should be shown else: # Object left red zone, reset counters self.frames_in_red_zone = 0 self.red_zone_entry_frame = None self.warning_triggered = False return False def get_similarity_with(self, other_bbox, similarity_threshold=0.5): """Calculate similarity with another bounding box""" if len(self.bboxes) == 0: return 0.0 current_bbox = self.bboxes[-1] return calculate_bbox_similarity(current_bbox, other_bbox) def is_similar_object(obj1, obj2, similarity_threshold=0.6): """Check if two objects are similar based on class, position and bounding box similarity""" try: if obj1['class'] != obj2['class']: return False box1 = obj1['bbox'] box2 = obj2['bbox'] # Convert to x1,y1,x2,y2 format if needed if len(box1) == 4 and len(box2) == 4: if box1[2] < box1[0] or box1[3] < box1[1]: # Already in x1,y1,x2,y2 bbox1 = box1 else: # Convert from x,y,w,h to x1,y1,x2,y2 bbox1 = [box1[0], box1[1], box1[0] + box1[2], box1[1] + box1[3]] if box2[2] < box2[0] or box2[3] < box2[1]: # Already in x1,y1,x2,y2 bbox2 = box2 else: # Convert from x,y,w,h to x1,y1,x2,y2 bbox2 = [box2[0], box2[1], box2[0] + box2[2], box2[1] + box2[3]] similarity = calculate_bbox_similarity(bbox1, bbox2) return similarity > similarity_threshold return False except Exception as e: return False # Global state for protection area and previous detections class State: def __init__(self): self.protection_points = [] # Store clicked points self.detected_segments = [] self.segment_image = None self.selected_segments = [] self.previous_detections = None self.cached_protection_area = None self.current_image = None # Store current image for drawing self.original_dims = None # Store original image dimensions self.display_dims = None # Store display dimensions self.tracked_objects = {} # Dictionary of tracked objects self.next_obj_id = 0 # Counter for generating unique object IDs self.object_count = defaultdict(int) # Count by class self.frame_count = 0 # Count processed frames self.red_zone_passed_objects = defaultdict(int) # Objects that passed through red zone self.red_zone_warnings = [] # Store warning messages self.time_window = 10 # Configurable time window for similarity comparison self.similarity_threshold = 0.6 # Configurable similarity threshold def reset_tracking(self): """Reset all tracking data""" self.tracked_objects = {} self.next_obj_id = 0 self.object_count = defaultdict(int) self.frame_count = 0 self.red_zone_passed_objects = defaultdict(int) self.red_zone_warnings = [] state = State() def image_to_bytes(image): """Convert PIL Image to bytes for API request""" # Log original image size original_width, original_height = image.size print(f"Original image dimensions: {original_width}x{original_height}") # Convert image to bytes without resizing img_byte_arr = io.BytesIO() image.save(img_byte_arr, format='PNG') print(f"Sending image with original dimensions: {original_width}x{original_height}") return img_byte_arr.getvalue() def base64_to_image(base64_str): """Convert base64 string to OpenCV image""" img_data = base64.b64decode(base64_str) nparr = np.frombuffer(img_data, np.uint8) return cv2.imdecode(nparr, cv2.IMREAD_COLOR) def opencv_to_pil(opencv_image): """Convert OpenCV image to PIL format""" # Convert from BGR to RGB for PIL rgb_image = cv2.cvtColor(opencv_image, cv2.COLOR_BGR2RGB) return Image.fromarray(rgb_image) def scale_point_to_original(x, y): """Scale display coordinates back to original image coordinates""" if state.original_dims is None or state.display_dims is None: return x, y orig_w, orig_h = state.original_dims disp_w, disp_h = state.display_dims # Calculate scaling factors scale_x = orig_w / disp_w scale_y = orig_h / disp_h # Scale the coordinates orig_x = int(x * scale_x) orig_y = int(y * scale_y) return orig_x, orig_y def scale_points_to_display(points): """Scale points from original image coordinates to display coordinates""" if state.original_dims is None or state.display_dims is None: return points orig_w, orig_h = state.original_dims disp_w, disp_h = state.display_dims # Calculate scaling factors scale_x = disp_w / orig_w scale_y = disp_h / orig_h # Scale all points display_points = [] for point in points: x = int(point[0] * scale_x) y = int(point[1] * scale_y) display_points.append([x, y]) return display_points def draw_protection_area(image): """Draw protection area points and lines on the image""" img = image.copy() points = state.protection_points # Draw existing points and lines if len(points) > 0: # Convert points to numpy array points_array = np.array(points, dtype=np.int32) # Draw lines between points if len(points) > 1: cv2.polylines(img, [points_array], True if len(points) == 4 else False, (0, 255, 0), 2) # Draw points with numbers for i, point in enumerate(points): cv2.circle(img, tuple(point), 5, (0, 0, 255), -1) cv2.putText(img, str(i+1), (point[0]+10, point[1]+10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2) # Fill polygon with semi-transparent color if we have at least 3 points if len(points) >= 3: overlay = img.copy() cv2.fillPoly(overlay, [points_array], (0, 255, 0)) cv2.addWeighted(overlay, 0.3, img, 0.7, 0, img) return img def update_preview(video): if video is None: return None, [], gr.update(visible=False) cap = cv2.VideoCapture(video) ret, frame = cap.read() cap.release() if ret: # Reset state state.protection_points = [] state.detected_segments = [] state.segment_image = None state.selected_segments = [] state.previous_detections = None state.cached_protection_area = None # Store original frame and its dimensions state.current_image = frame.copy() # Store the original frame state.original_dims = (frame.shape[1], frame.shape[0]) # (width, height) state.display_dims = state.original_dims # Set display dims same as original # Convert to RGB without resizing frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) return frame_rgb, gr.update(choices=[], value=[], visible=False) return None, gr.update(choices=[], value=[], visible=False) def handle_image_click(evt: gr.SelectData, img): """Handle mouse clicks on the image""" if len(state.protection_points) >= 4: # Reset points if we already have 4 state.protection_points = [] if state.current_image is None: return img, "Error: No image loaded" # Get click coordinates from the event - these are now in original scale click_x, click_y = evt.index[0], evt.index[1] # Add point directly (no scaling needed as we're working with original coordinates) state.protection_points.append([click_x, click_y]) # Create a copy of the current image for display display_img = state.current_image.copy() # Draw points and lines for i, point in enumerate(state.protection_points): # Draw point cv2.circle(display_img, (point[0], point[1]), 5, (0, 0, 255), -1) cv2.putText(display_img, str(i+1), (point[0] + 10, point[1] + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2) # Draw lines between points if len(state.protection_points) > 1: points_array = np.array(state.protection_points, dtype=np.int32) # Draw lines cv2.polylines(display_img, [points_array], True if len(state.protection_points) == 4 else False, (0, 255, 0), 2) # Fill polygon with semi-transparent color if we have at least 3 points if len(state.protection_points) >= 3: overlay = display_img.copy() cv2.fillPoly(overlay, [points_array], (0, 255, 0)) cv2.addWeighted(overlay, 0.3, display_img, 0.7, 0, display_img) # Convert to RGB for display display_img_rgb = cv2.cvtColor(display_img, cv2.COLOR_BGR2RGB) # Return the image and status return display_img_rgb, f"Selected {len(state.protection_points)} points\nCoordinates: {state.protection_points}" def reset_points(): """Reset protection points""" state.protection_points = [] if state.current_image is not None: # Convert original image to RGB for display display_img_rgb = cv2.cvtColor(state.current_image.copy(), cv2.COLOR_BGR2RGB) return display_img_rgb, "Points reset" return None, "Points reset" def detect_rail_segments(image): """Detect rail segments using the API""" try: # Log original image dimensions width, height = image.size print(f"Detecting rail segments on image with dimensions: {width}x{height}") files = {"file": image_to_bytes(image)} response = requests.post( f"{API_URL}/detect/rail-segment", files=files, timeout=60 ) if response.status_code == 200: result = response.json() if "segments" in result: return result["segments"], base64_to_image(result["image_base64"]) else: return [], None else: print(f"API error: {response.status_code} - Image size was {width}x{height}") return [], None except Exception as e: print(f"Error in detect_rail_segments: {str(e)}") return [], None def extract_protection_area(first_frame): """Extract and cache protection area points using rail segment detection""" try: # Log original frame dimensions height, width = first_frame.shape[:2] print(f"Extracting protection area from frame with dimensions: {width}x{height}") # Convert frame to PIL Image without resizing first_frame_pil = Image.fromarray(cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)) # Verify PIL image dimensions pil_width, pil_height = first_frame_pil.size print(f"PIL Image dimensions before API call: {pil_width}x{pil_height}") # Detect rail segments segments, segment_img = detect_rail_segments(first_frame_pil) if segments and len(segments) > 0: # Verify segment image dimensions if segment_img is not None: seg_height, seg_width = segment_img.shape[:2] print(f"Received segment image dimensions: {seg_width}x{seg_height}") # Only resize if dimensions don't match if (seg_width, seg_height) != (width, height): print(f"Resizing segment image from {seg_width}x{seg_height} to {width}x{height}") segment_img = cv2.resize(segment_img, (width, height), interpolation=cv2.INTER_LANCZOS4) # Store segments and image state.detected_segments = segments state.segment_image = segment_img # Create segment choices with more detailed information segment_choices = [] for i, segment in enumerate(segments): # Extract mask dimensions for verification mask_points = segment.get('mask', []) if mask_points: mask_x = [p[0] for p in mask_points] mask_y = [p[1] for p in mask_points] mask_width = max(mask_x) - min(mask_x) mask_height = max(mask_y) - min(mask_y) print(f"Segment {i+1} mask dimensions: {mask_width}x{mask_height}") choice_text = f"Segment {i+1} (Confidence: {segment['confidence']:.2f})" segment_choices.append(choice_text) state.selected_segments = segment_choices # Select all segments by default # Use the first segment's mask as protection area segment = segments[0] if 'mask' in segment and segment['mask']: mask_points = segment['mask'] # Convert to list of [x,y] points and ensure integer values mask_points = [[int(float(x)), int(float(y))] for x, y in mask_points] if len(mask_points) >= 3: # Need at least 3 points for a valid polygon state.cached_protection_area = mask_points # Convert segment image to RGB for display without resizing if segment_img is not None: display_img = cv2.cvtColor(segment_img, cv2.COLOR_BGR2RGB) return True, "Protection area extracted successfully", display_img return False, "Invalid mask points in segment", None return False, "No valid rail segments detected", None except Exception as e: print(f"Error in extract_protection_area: {str(e)}") return False, f"Error extracting protection area: {str(e)}", None def get_segment_index(choice_text): """Extract segment index from choice text""" try: # Extract index from "Segment X (Confidence: Y)" format return int(choice_text.split()[1]) - 1 except: return -1 def update_object_tracking(objects_in_area): """Update object tracking with new detections""" try: current_tracked = set() # Keep track of objects seen in this frame current_warnings = [] # Collect warnings for this frame # Match new detections with existing tracked objects for obj in objects_in_area: try: if 'bbox' not in obj or 'class' not in obj: continue bbox = obj['bbox'] obj_class = obj['class'] is_in_red_zone = obj.get('in_protection_area', False) matched = False best_match_id = None best_similarity = 0.0 # Try to match with existing tracked objects using similarity method for obj_id, tracked in state.tracked_objects.items(): if tracked.class_name == obj_class: # Check if object was seen recently (within time window) if state.frame_count - tracked.last_seen <= state.time_window: similarity = tracked.get_similarity_with(bbox) # Use the best match above threshold if similarity > state.similarity_threshold and similarity > best_similarity: best_similarity = similarity best_match_id = obj_id # If good match found, update existing object if best_match_id is not None: tracked = state.tracked_objects[best_match_id] tracked.add_detection(bbox) tracked.last_seen = state.frame_count current_tracked.add(best_match_id) matched = True # Check red zone status and warnings warning_triggered = tracked.update_red_zone_status(is_in_red_zone, state.frame_count) if warning_triggered: warning_msg = f"āš ļø WARNING: {tracked.class_name} (ID: {tracked.id}) has been in red zone for {tracked.frames_in_red_zone} frames!" current_warnings.append(warning_msg) state.red_zone_warnings.append({ 'frame': state.frame_count, 'object_id': tracked.id, 'class': tracked.class_name, 'frames_in_zone': tracked.frames_in_red_zone, 'message': warning_msg }) # Check if object should be counted (only count objects that actually move through the zone) if not tracked.counted and tracked.has_movement() and is_in_red_zone: # Additional check: object should have been tracked for at least a few frames if len(tracked.trajectory) >= 3: tracked.counted = True state.red_zone_passed_objects[obj_class] += 1 # If no match found, create new tracked object if not matched: new_obj = TrackedObject(state.next_obj_id, obj_class, bbox) new_obj.last_seen = state.frame_count new_obj.first_seen = state.frame_count state.tracked_objects[state.next_obj_id] = new_obj current_tracked.add(state.next_obj_id) state.next_obj_id += 1 # Check red zone status for new object new_obj.update_red_zone_status(is_in_red_zone, state.frame_count) except Exception as e: continue # Update objects not seen in current frame for obj_id, tracked in state.tracked_objects.items(): if obj_id not in current_tracked: # Object not seen in current frame, update red zone status tracked.update_red_zone_status(False, state.frame_count) # Remove objects that haven't been seen for a while if state.frame_count > state.time_window: to_remove = [] for obj_id, tracked in state.tracked_objects.items(): if state.frame_count - tracked.last_seen > state.time_window * 2: # Remove after 2x time window to_remove.append(obj_id) for obj_id in to_remove: del state.tracked_objects[obj_id] # Store current warnings if current_warnings: print(f"Frame {state.frame_count} Warnings: {current_warnings}") except Exception as e: print(f"Error in update_object_tracking: {str(e)}") def get_red_zone_summary(): """Generate summary of objects that passed through red zone""" summary = [] if state.red_zone_passed_objects: summary.append("šŸ”“ RED ZONE PASSAGE SUMMARY:") total_objects = sum(state.red_zone_passed_objects.values()) summary.append(f"Total objects passed: {total_objects}") for obj_class, count in sorted(state.red_zone_passed_objects.items()): summary.append(f" • {obj_class}: {count}") # Add current objects in red zone current_in_zone = [] for obj_id, tracked in state.tracked_objects.items(): if tracked.frames_in_red_zone > 0: current_in_zone.append(f"{tracked.class_name} (ID: {tracked.id}, {tracked.frames_in_red_zone} frames)") if current_in_zone: summary.append("\n🚨 CURRENTLY IN RED ZONE:") for obj_info in current_in_zone: summary.append(f" • {obj_info}") # Add recent warnings recent_warnings = [w for w in state.red_zone_warnings if state.frame_count - w['frame'] <= 5] if recent_warnings: summary.append("\nāš ļø RECENT WARNINGS:") for warning in recent_warnings[-3:]: # Show last 3 warnings summary.append(f" • Frame {warning['frame']}: {warning['message']}") return "\n".join(summary) if summary else "No objects detected in red zone yet." def process_frame(frame, confidence): """Process a video frame using cached protection area""" try: protection_area = [] if state.selected_segments and state.detected_segments: for choice in state.selected_segments: idx = get_segment_index(choice) if 0 <= idx < len(state.detected_segments): segment = state.detected_segments[idx] if 'mask' in segment and segment['mask']: protection_area = segment['mask'] break elif len(state.protection_points) >= 3: protection_area = state.protection_points if not protection_area: return None, "Protection area not set. Please extract protection area first." # Ensure frame is valid if frame is None or frame.size == 0: return None, "Invalid frame" success, buffer = cv2.imencode('.png', frame) if not success: return None, "Failed to encode frame" files = { "file": ("frame.png", buffer.tobytes(), "image/png") } protection_area_json = json.dumps(protection_area) data = { "protection_area": protection_area_json, "confidence_threshold": str(confidence) } if state.previous_detections: data["previous_detections"] = json.dumps(state.previous_detections) try: response = requests.post( f"{API_URL}/detect/objects-and-redlight", files=files, data=data, timeout=60 ) if response.status_code == 200: result = response.json() if not result.get("success"): return None, f"API Error: {result.get('detail', 'Unknown error')}" result_data = result.get("result", {}) if not result_data: return None, "No result data received" red_light_info = result_data.get("red_light", {}) red_light_detected = red_light_info.get("detected", False) red_light_prob = red_light_info.get("probability", 0) img_base64 = result_data.get("image_base64") if not img_base64: return None, "No image data received from API" try: if ',' in img_base64: img_base64 = img_base64.split(',')[1] img_data = base64.b64decode(img_base64) nparr = np.frombuffer(img_data, np.uint8) processed_img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) if processed_img is None or processed_img.size == 0: return None, "Failed to decode image from API response" objects_in_area = [obj for obj in result_data.get("objects", []) if obj.get("in_protection_area", False) and 'bbox' in obj and 'class' in obj] # Update object tracking state.frame_count += 1 update_object_tracking(objects_in_area) # Cache detections for next frame state.previous_detections = objects_in_area processed_img_rgb = cv2.cvtColor(processed_img, cv2.COLOR_BGR2RGB) status = [] status.append(f"Red Light: {'YES' if red_light_detected else 'NO'} ({red_light_prob:.2f})") # Add enhanced red zone summary red_zone_summary = get_red_zone_summary() status.append(f"\n{red_zone_summary}") if objects_in_area: status.append("\nšŸ“Š CURRENT FRAME DETECTIONS:") for obj in objects_in_area: status.append(f" • {obj['class']} (confidence: {obj['confidence']:.2f})") # Add tracking statistics active_objects = len([obj for obj in state.tracked_objects.values() if state.frame_count - obj.last_seen <= 3]) status.append(f"\nšŸ“ˆ TRACKING STATS:") status.append(f" • Active tracked objects: {active_objects}") status.append(f" • Frame: {state.frame_count}") status.append(f" • Time window: {state.time_window} frames") status.append(f" • Similarity threshold: {state.similarity_threshold:.2f}") return processed_img_rgb, "\n".join(status) except Exception as e: return None, f"Error processing detection results: {str(e)}" else: error_detail = f"API Error: {response.status_code}" try: error_json = response.json() if 'detail' in error_json: error_detail += f" - {error_json['detail']}" except: error_detail += f" - {response.text}" return None, error_detail except requests.exceptions.Timeout: return None, "API request timed out" except requests.exceptions.ConnectionError: return None, "Could not connect to API server" except Exception as e: return None, f"API request failed: {str(e)}" except Exception as e: return None, f"Error processing frame: {str(e)}" def process_video(video, confidence=DEFAULT_CONFIDENCE, target_fps=1): """Stream processed frames in real-time using cached protection area""" detection_results = [] cap = cv2.VideoCapture(video) if not cap.isOpened(): yield None, "Error: Could not open video file" return total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) fps = cap.get(cv2.CAP_PROP_FPS) frame_interval = max(1, int(fps / target_fps)) frame_number = 0 try: while True: ret, frame = cap.read() if not ret: break frame_number += 1 if frame_number % frame_interval != 0: continue # Process frame and get results processed_frame, result = process_frame(frame, confidence) if processed_frame is not None: # Frame is already in RGB format from process_frame current_status = f"Processing frame {frame_number}/{total_frames}\n{result}" yield processed_frame, current_status else: current_status = f"Frame {frame_number}: {result}" yield None, current_status # Release resources cap.release() # Generate final summary final_summary = generate_final_summary() yield None, final_summary except Exception as e: yield None, f"Error processing video: {str(e)}" finally: cap.release() def generate_final_summary(): """Generate comprehensive final summary of video processing""" summary_lines = [] summary_lines.append("šŸŽ¬ VIDEO PROCESSING COMPLETE") summary_lines.append("=" * 50) # Processing statistics summary_lines.append(f"šŸ“Š PROCESSING STATISTICS:") summary_lines.append(f" • Total frames processed: {state.frame_count}") summary_lines.append(f" • Time window used: {state.time_window} frames") summary_lines.append(f" • Similarity threshold: {state.similarity_threshold:.2f}") # Red zone passage summary if state.red_zone_passed_objects: summary_lines.append(f"\nšŸ”“ RED ZONE PASSAGE SUMMARY:") total_passed = sum(state.red_zone_passed_objects.values()) summary_lines.append(f" • Total objects passed through red zone: {total_passed}") for obj_class, count in sorted(state.red_zone_passed_objects.items()): summary_lines.append(f" - {obj_class}: {count}") else: summary_lines.append(f"\nšŸ”“ RED ZONE PASSAGE SUMMARY:") summary_lines.append(f" • No objects detected passing through red zone") # Warning summary if state.red_zone_warnings: summary_lines.append(f"\nāš ļø WARNING SUMMARY:") summary_lines.append(f" • Total warnings generated: {len(state.red_zone_warnings)}") # Group warnings by object class warning_by_class = defaultdict(int) for warning in state.red_zone_warnings: warning_by_class[warning['class']] += 1 for obj_class, count in sorted(warning_by_class.items()): summary_lines.append(f" - {obj_class}: {count} warnings") # Show last few warnings if len(state.red_zone_warnings) > 0: summary_lines.append(f"\n šŸ“‹ Recent warnings:") for warning in state.red_zone_warnings[-5:]: # Last 5 warnings summary_lines.append(f" - Frame {warning['frame']}: {warning['class']} (ID: {warning['object_id']}) - {warning['frames_in_zone']} frames in zone") else: summary_lines.append(f"\nāš ļø WARNING SUMMARY:") summary_lines.append(f" • No warnings generated (no objects stayed in red zone > 3 frames)") # Active tracking summary total_tracked = len(state.tracked_objects) if total_tracked > 0: summary_lines.append(f"\nšŸ“ˆ OBJECT TRACKING SUMMARY:") summary_lines.append(f" • Total unique objects tracked: {total_tracked}") # Group by class objects_by_class = defaultdict(int) for obj in state.tracked_objects.values(): objects_by_class[obj.class_name] += 1 for obj_class, count in sorted(objects_by_class.items()): summary_lines.append(f" - {obj_class}: {count}") summary_lines.append("\nāœ… Processing completed successfully!") return "\n".join(summary_lines) def extract_area_from_video(video): if video is None: return None, "Please upload a video", gr.update(choices=[], value=[], visible=False) cap = cv2.VideoCapture(video) ret, frame = cap.read() cap.release() if not ret: return None, "Could not read video frame", gr.update(choices=[], value=[], visible=False) success, message, segment_img = extract_protection_area(frame) if success and segment_img is not None: # Convert segment image to RGB for display segment_img_rgb = cv2.cvtColor(segment_img, cv2.COLOR_BGR2RGB) # Create segment choices segment_choices = [f"Segment {i+1} (Confidence: {segment['confidence']:.2f})" for i, segment in enumerate(state.detected_segments)] return segment_img_rgb, message, gr.update(choices=segment_choices, value=segment_choices, visible=True) return None, message, gr.update(choices=[], value=[], visible=False) def update_selected_segments(selected): if selected is None: selected = [] state.selected_segments = selected return gr.update() def process_video_wrapper(video, confidence=DEFAULT_CONFIDENCE, target_fps=1, time_window=10, similarity_threshold=0.6): """Wrapper around process_video to handle full-size video processing""" if video is None: yield None, "Please upload a video" return # Reset tracking state and update parameters state.reset_tracking() state.time_window = time_window state.similarity_threshold = similarity_threshold protection_area = [] if state.selected_segments and state.detected_segments: for choice in state.selected_segments: idx = get_segment_index(choice) if 0 <= idx < len(state.detected_segments): segment = state.detected_segments[idx] if 'mask' in segment and segment['mask']: protection_area = segment['mask'] break elif len(state.protection_points) >= 3: protection_area = state.protection_points if not protection_area: yield None, "Please extract protection area first" return try: yield None, f"šŸš€ Starting video processing...\nāš™ļø Time window: {time_window} frames\nāš™ļø Similarity threshold: {similarity_threshold:.2f}" for frame, status in process_video(video, confidence, target_fps): yield frame, status except Exception as e: yield None, f"Error processing video: {str(e)}" # Update the Gradio interface with gr.Blocks(title="Enhanced Rail Traffic Monitor") as demo: gr.Markdown(""" # Enhanced Rail Traffic Monitoring System ## Features: - **Smart Object Tracking**: Uses similarity method to track objects across frames - **Red Zone Monitoring**: Counts objects passing through the red zone - **Warning System**: Alerts when objects stay in red zone for more than 3 frames - **Configurable Parameters**: Adjust time window and similarity threshold ## Setup Instructions: **Method 1 (Manual Protection Area):** 1. Click 4 points on the image to define protection area 2. Click "Reset Points" to start over **Method 2 (Automatic Detection):** 1. Click "Extract Protection Area" to automatically detect rail segments **Processing:** 3. Adjust detection confidence, processing frame rate, time window, and similarity threshold 4. Click "Process Video" to analyze The system will show real-time results including: - Objects currently in red zone - Total count of objects that passed through - Warnings for objects staying too long in red zone - Tracking statistics """) with gr.Row(): with gr.Column(): video_input = gr.Video( label="Input Video" ) with gr.Row(): confidence = gr.Slider( minimum=0.0, maximum=1.0, value=DEFAULT_CONFIDENCE, label="Detection Confidence Threshold", info="Minimum confidence for object detection" ) fps_slider = gr.Slider( minimum=1, maximum=30, value=1, step=1, label="Processing Frame Rate (FPS)", info="Frames per second to process" ) with gr.Row(): time_window_slider = gr.Slider( minimum=5, maximum=50, value=10, step=1, label="Time Window (frames)", info="Number of frames to consider for object similarity" ) similarity_threshold_slider = gr.Slider( minimum=0.1, maximum=0.9, value=0.6, step=0.05, label="Similarity Threshold", info="Threshold for considering objects as the same (higher = stricter)" ) with gr.Column(): preview_image = gr.Image( label="Click to Select Protection Area (Original Size)", interactive=True, show_label=True ) # Add segment selection dropdown segment_dropdown = gr.Dropdown( label="Selected Segments", choices=[], multiselect=True, interactive=True, visible=False, value=[] ) with gr.Row(): reset_btn = gr.Button("Reset Points") extract_btn = gr.Button("Extract Protection Area") process_btn = gr.Button("šŸš€ Process Video") with gr.Row(): video_output = gr.Image( label="Live Processing Output", streaming=True, interactive=False, show_label=True, container=True, show_download_button=True ) text_output = gr.Textbox( label="Detection Results & Red Zone Summary", lines=15, max_lines=20, show_copy_button=True ) # Handle video upload to populate preview video_input.change( fn=update_preview, inputs=[video_input], outputs=[preview_image, segment_dropdown] ) extract_btn.click( fn=extract_area_from_video, inputs=[video_input], outputs=[preview_image, text_output, segment_dropdown] ) segment_dropdown.change( fn=update_selected_segments, inputs=[segment_dropdown], outputs=[segment_dropdown] ) process_btn.click( fn=process_video_wrapper, inputs=[video_input, confidence, fps_slider, time_window_slider, similarity_threshold_slider], outputs=[video_output, text_output] ) # Add click event handler preview_image.select( fn=handle_image_click, inputs=[preview_image], outputs=[preview_image, text_output] ) # Add reset button handler reset_btn.click( fn=reset_points, inputs=[], outputs=[preview_image, text_output] ) if __name__ == "__main__": demo.queue().launch()