import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from shapely.geometry import Polygon, box as shapely_box
import gradio as gr
import time
import spaces
# Utility functions
def extract_class_0_coordinates(filename):
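    """Collect the flattened, normalized polygon coordinates of every
    class-0 (rail) entry in a YOLO segmentation label file, where each
    line is `class x1 y1 x2 y2 ...` with values normalized to [0, 1]."""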
class_0_coordinates = []
with open(filename, 'r') as file:
for line in file:
parts = line.strip().split()
if len(parts) == 0:
continue
if parts[0] == '0':
coordinates = [float(x) for x in parts[1:]]
class_0_coordinates.extend(coordinates)
return class_0_coordinates
def read_yolo_boxes(file_path):
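    """Parse a YOLO detection label file into (class_name, x_center,
    y_center, width, height) tuples, coordinates normalized to [0, 1]."""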
boxes = []
with open(file_path, 'r') as f:
for line in f:
parts = line.strip().split()
class_name = COCO_CLASSES[int(parts[0])]
x, y, w, h = map(float, parts[1:5])
boxes.append((class_name, x, y, w, h))
return boxes
def yolo_to_pixel_coords(x_center, y_center, width, height, img_width, img_height):
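    """Convert a normalized YOLO center/size box to pixel corner
    coordinates (x1, y1, x2, y2)."""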
x1 = int((x_center - width / 2) * img_width)
y1 = int((y_center - height / 2) * img_height)
x2 = int((x_center + width / 2) * img_width)
y2 = int((y_center + height / 2) * img_height)
return x1, y1, x2, y2
def convert_segment_to_pixel(segment, img_width, img_height):
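    """Turn a flat [x1, y1, x2, y2, ...] normalized segment into a list
    of (x, y) pixel tuples."""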
return [(int(x * img_width), int(y * img_height)) for x, y in zip(segment[::2], segment[1::2])]
def box_segment_relationship(yolo_box, segment, img_width, img_height, threshold):
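    """Classify a detection box against the rail polygon: "intersecting"
    if the box overlaps the polygon, "obstructed" if it lies within
    `threshold` pixels of it, and "not touching" otherwise."""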
class_id, x_center, y_center, width, height = yolo_box
x1, y1, x2, y2 = yolo_to_pixel_coords(x_center, y_center, width, height, img_width, img_height)
pixel_segment = convert_segment_to_pixel(segment, img_width, img_height)
segment_polygon = Polygon(pixel_segment)
box_polygon = shapely_box(x1, y1, x2, y2)
if box_polygon.intersects(segment_polygon):
return "intersecting"
elif box_polygon.distance(segment_polygon) <= threshold:
return "obstructed"
else:
return "not touching"
# COCO classes
COCO_CLASSES = [
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
'hair drier', 'toothbrush'
]
# Detection functions
@spaces.GPU  # request a GPU slot on Spaces (the subprocess below still runs with --device cpu)
def detect_rail(image):
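    """Step 1: run the YOLOv9 segmentation model on an image, extract the
    class-0 rail polygon, and return (figure, segment, status message)."""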
    # Gradio delivers a numpy array here (type="numpy"); np.asarray also
    # accepts a PIL image if the input type is ever changed
    image = np.asarray(image)
# Check if the image is RGB (3 channels)
if len(image.shape) == 3 and image.shape[2] == 3:
# Convert RGB to BGR (OpenCV format)
image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
else:
# If not RGB, just use the image as is (assuming it's already in a format OpenCV can handle)
image_bgr = image
temp_image_path = "temp_image_rail.jpg"
cv2.imwrite(temp_image_path, image_bgr)
os.system(f"python segment/predict.py --source {temp_image_path} --img 640 --device cpu --weights models/segment/best-2.pt --name yolov9_c_640_detect --exist-ok --save-txt")
    label_path = 'runs/predict-seg/yolov9_c_640_detect/labels/temp_image_rail.txt'
    if not os.path.exists(label_path):
        os.remove(temp_image_path)
        return None, None, "Rail detection failed: no segmentation labels were produced."
    segment = extract_class_0_coordinates(label_path)
    if len(segment) < 6:  # a polygon needs at least three (x, y) points
        os.remove(temp_image_path)
        os.remove(label_path)
        return None, None, "Rail detection failed: no rail (class 0) segment was found."
fig, ax = plt.subplots(figsize=(12, 8))
ax.imshow(image) # Use the original image for display
img_height, img_width = image.shape[:2]
pixel_segment = convert_segment_to_pixel(segment, img_width, img_height)
ax.plot([x for x, _ in pixel_segment] + [pixel_segment[0][0]],
[y for _, y in pixel_segment] + [pixel_segment[0][1]],
'g-', linewidth=2, label='Rail Zone')
ax.legend()
ax.axis('off')
plt.tight_layout()
os.remove(temp_image_path)
os.remove(label_path)
return fig, segment, "Rail detection completed. You can now upload a video for object detection."
def create_sample_video(output_path, duration=10, fps=30, width=640, height=480):
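    """Write a `duration`-second video of random noise frames, used as a
    fallback when no video is uploaded."""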
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
for _ in range(duration * fps):
frame = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
out.write(frame)
out.release()
return output_path
@spaces.GPU  # request a GPU slot on Spaces (the subprocess below still runs with --device cpu)
def process_video(video_path, rail_segment, frame_skip=15):
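    """Step 2: two-pass obstruction check. Pass one runs YOLOv9 detection
    on every `frame_skip`-th frame; pass two rewinds the video, draws the
    rail polygon and color-coded boxes, and writes the annotated output."""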
if not os.path.exists(video_path):
return None, f"Error: Video file not found at {video_path}"
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
return None, "Error: Could not open video file."
fps = int(cap.get(cv2.CAP_PROP_FPS))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
# Create output directory if it doesn't exist
output_dir = 'output_videos'
os.makedirs(output_dir, exist_ok=True)
# Generate a unique filename based on timestamp
timestamp = int(time.time())
output_filename = f'processed_video_{timestamp}.mp4'
output_path = os.path.join(output_dir, output_filename)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # Only every frame_skip-th frame is written, so reduce the output fps to keep
    # real-time pacing; guard against fps // frame_skip == 0 when frame_skip exceeds fps
    out = cv2.VideoWriter(output_path, fourcc, max(1, fps // frame_skip), (width, height))
frame_count = 0
processed_count = 0
threshold = 10 # Set threshold (in pixels) for obstruction detection
obstructed_frames = 0
all_detections = []
# First pass: Detect objects in all frames
while True:
ret, frame = cap.read()
if not ret:
break
frame_count += 1
if frame_count % frame_skip != 0:
continue
processed_count += 1
# Save the frame as a temporary image
temp_frame_path = f"temp_frame_{processed_count:04d}.jpg"
cv2.imwrite(temp_frame_path, frame)
# Run object detection on the frame
os.system(f"python detect.py --source {temp_frame_path} --img 640 --device cpu --weights models/detect/yolov9-s-converted.pt --name yolov9_c_640_detect --exist-ok --save-txt")
        # Read detection results (YOLO writes no label file when nothing is detected)
        label_path = f'runs/detect/yolov9_c_640_detect/labels/temp_frame_{processed_count:04d}.txt'
        yolo_boxes = read_yolo_boxes(label_path) if os.path.exists(label_path) else []
        all_detections.append(yolo_boxes)
        os.remove(temp_frame_path)
        if os.path.exists(label_path):
            os.remove(label_path)
print(f"Processed frame {frame_count}/{total_frames} (Frame {processed_count})")
cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
frame_count = 0
processed_count = 0
    # Second pass: check for obstructions and create the output video.
    # The rail polygon is fixed, so convert it to pixel coordinates once.
    pixel_segment = convert_segment_to_pixel(rail_segment, width, height)
    while True:
ret, frame = cap.read()
if not ret:
break
frame_count += 1
if frame_count % frame_skip != 0:
continue
processed_count += 1
        # Draw the rail segment outline
        cv2.polylines(frame, [np.array(pixel_segment)], True, (0, 255, 0), 2)
# Check for obstructions and draw bounding boxes
frame_obstructed = False
for box in all_detections[processed_count - 1]:
class_name, x, y, w, h = box
relationship = box_segment_relationship((0, x, y, w, h), rail_segment, width, height, threshold)
x1, y1, x2, y2 = yolo_to_pixel_coords(x, y, w, h, width, height)
if relationship == "intersecting":
color = (0, 0, 255) # Red for intersecting
frame_obstructed = True
elif relationship == "obstructed":
color = (0, 255, 255) # Yellow for obstructed
frame_obstructed = True
else:
color = (255, 0, 0) # Blue for not touching
cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
cv2.putText(frame, f"{class_name} ({relationship})", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
if frame_obstructed:
obstructed_frames += 1
out.write(frame)
print(f"Processed frame {frame_count}/{total_frames} (Frame {processed_count})")
cap.release()
out.release()
if processed_count == 0:
return None, "Error: No frames were processed."
obstruction_percentage = (obstructed_frames / processed_count) * 100
summary = f"Video processing completed. Processed {processed_count} out of {total_frames} frames.\n\n"
summary += f"Obstruction Summary:\n"
summary += f"Total processed frames: {processed_count}\n"
summary += f"Frames with obstructions: {obstructed_frames}\n"
summary += f"Percentage of frames with obstructions: {obstruction_percentage:.2f}%\n"
summary += f"Output video saved as: {output_path}"
return output_path, summary
# Gradio interface
class TwoStepDetection:
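    """Holds the rail segment produced in step 1 for use in step 2.

    Note: a single module-level instance is shared by all Gradio sessions,
    so concurrent users would overwrite each other's rail segment; per-user
    isolation would need something like gr.State."""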
def __init__(self):
self.rail_segment = None
def rail_detection(self, rail_input):
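        """Gradio handler for step 1: detect the rail and cache its segment."""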
if rail_input is None:
return None, "Please upload an image for rail detection."
rail_fig, self.rail_segment, message = detect_rail(rail_input)
return rail_fig, message
def object_detection(self, video_input, frame_skip=15):
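        """Gradio handler for step 2: run obstruction detection on a video,
        falling back to a generated sample video when none is uploaded."""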
if self.rail_segment is None:
return None, "Please complete rail detection first."
if video_input is None:
# Create a sample video if none is provided
sample_video_path = "sample_video.mp4"
create_sample_video(sample_video_path)
video_input = sample_video_path
video_output, processing_message = process_video(video_input, self.rail_segment, frame_skip)
if video_output is None:
return None, processing_message
return video_output, processing_message
# Create Gradio interface
detector = TwoStepDetection()
with gr.Blocks(title="Two-Step Train Obstruction Detection") as iface:
gr.Markdown("# Two-Step Train Obstruction Detection")
gr.Markdown("Step 1: Upload an image to detect the rail. Step 2: Upload a video to detect obstructions.")
with gr.Tab("Step 1: Rail Detection"):
rail_input = gr.Image(type="numpy", label="Upload image for rail detection")
rail_output = gr.Plot(label="Rail Detection Result")
rail_message = gr.Textbox(label="Rail Detection Message")
rail_button = gr.Button("Detect Rail")
with gr.Tab("Step 2: Object Detection"):
video_input = gr.Video(label="Upload video for object detection")
frame_skip = gr.Slider(minimum=1, maximum=100, step=1, value=15, label="Frame Skip Rate")
video_output = gr.Video(label="Object Detection Result")
object_message = gr.Textbox(label="Object Detection Results")
object_button = gr.Button("Detect Objects")
rail_button.click(detector.rail_detection, inputs=rail_input, outputs=[rail_output, rail_message])
object_button.click(detector.object_detection, inputs=[video_input, frame_skip], outputs=[video_output, object_message])
# Launch the Gradio app
if __name__ == "__main__":
iface.launch()