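"""Two-step train obstruction detection demo (Gradio app).

Step 1: an uploaded image is run through a YOLOv9 segmentation model
(models/segment/best-2.pt) to locate the rail zone (class 0).
Step 2: an uploaded video is run frame-by-frame through a YOLOv9 detection
model (models/detect/yolov9-s-converted.pt); detected objects are marked as
intersecting, obstructed (within a pixel threshold of the rail polygon), or
not touching, and an annotated video plus a text summary is produced.
"""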
import os
import time

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import spaces
from PIL import Image
from shapely.geometry import Polygon, box as shapely_box
|
|
@spaces.GPU
def extract_class_0_coordinates(filename):
    """Return the flattened, normalized polygon coordinates of every class-0 (rail) line in a YOLO segmentation label file."""
    class_0_coordinates = []
    with open(filename, 'r') as file:
        for line in file:
            parts = line.strip().split()
            if len(parts) == 0:
                continue
            if parts[0] == '0':
                coordinates = [float(x) for x in parts[1:]]
                class_0_coordinates.extend(coordinates)
    return class_0_coordinates
|
|
def read_yolo_boxes(file_path):
    """Read a YOLO detection label file and return (class_name, x_center, y_center, width, height) tuples."""
    boxes = []
    with open(file_path, 'r') as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) < 5:
                continue
            class_name = COCO_CLASSES[int(parts[0])]
            x, y, w, h = map(float, parts[1:5])
            boxes.append((class_name, x, y, w, h))
    return boxes
|
|
def yolo_to_pixel_coords(x_center, y_center, width, height, img_width, img_height):
    """Convert a normalized YOLO box (center x/y, width, height) to pixel corner coordinates (x1, y1, x2, y2)."""
    x1 = int((x_center - width / 2) * img_width)
    y1 = int((y_center - height / 2) * img_height)
    x2 = int((x_center + width / 2) * img_width)
    y2 = int((y_center + height / 2) * img_height)
    return x1, y1, x2, y2
|
|
def convert_segment_to_pixel(segment, img_width, img_height):
    """Convert a flat list of normalized [x1, y1, x2, y2, ...] coordinates to pixel (x, y) tuples."""
    return [(int(x * img_width), int(y * img_height)) for x, y in zip(segment[::2], segment[1::2])]
|
|
def box_segment_relationship(yolo_box, segment, img_width, img_height, threshold):
    """Classify a YOLO box against the rail polygon: "intersecting", "obstructed" (within `threshold` pixels), or "not touching"."""
    _class_id, x_center, y_center, width, height = yolo_box
    x1, y1, x2, y2 = yolo_to_pixel_coords(x_center, y_center, width, height, img_width, img_height)
    pixel_segment = convert_segment_to_pixel(segment, img_width, img_height)
    segment_polygon = Polygon(pixel_segment)
    box_polygon = shapely_box(x1, y1, x2, y2)

    if box_polygon.intersects(segment_polygon):
        return "intersecting"
    elif box_polygon.distance(segment_polygon) <= threshold:
        return "obstructed"
    else:
        return "not touching"
|
|
COCO_CLASSES = [
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
    'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
    'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
    'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
    'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
    'hair drier', 'toothbrush'
]
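# Detection label class indices map into COCO_CLASSES above; the segmentation
# model's class 0 is treated separately as the rail zone (see extract_class_0_coordinates).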
|
|
|
|
|
def detect_rail(image):
    """Run YOLOv9 segmentation on an image and return a plot of the detected rail zone plus its normalized coordinates."""
    image = np.array(image)

    # Gradio delivers RGB; OpenCV expects BGR when writing to disk.
    if len(image.shape) == 3 and image.shape[2] == 3:
        image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    else:
        image_bgr = image

    temp_image_path = "temp_image_rail.jpg"
    cv2.imwrite(temp_image_path, image_bgr)

    # Run the segmentation model as an external script; results are written as YOLO label files.
    os.system(f"python segment/predict.py --source {temp_image_path} --img 640 --device cpu --weights models/segment/best-2.pt --name yolov9_c_640_detect --exist-ok --save-txt")

    label_path = 'runs/predict-seg/yolov9_c_640_detect/labels/temp_image_rail.txt'

    # No label file (or no class-0 line) means no rail was segmented; bail out instead of crashing.
    segment = extract_class_0_coordinates(label_path) if os.path.exists(label_path) else []
    if not segment:
        os.remove(temp_image_path)
        if os.path.exists(label_path):
            os.remove(label_path)
        return None, None, "No rail was detected in the image. Please try a different image."

    fig, ax = plt.subplots(figsize=(12, 8))
    ax.imshow(image)

    # Close the polygon by appending the first point again.
    img_height, img_width = image.shape[:2]
    pixel_segment = convert_segment_to_pixel(segment, img_width, img_height)
    ax.plot([x for x, _ in pixel_segment] + [pixel_segment[0][0]],
            [y for _, y in pixel_segment] + [pixel_segment[0][1]],
            'g-', linewidth=2, label='Rail Zone')

    ax.legend()
    ax.axis('off')
    plt.tight_layout()

    os.remove(temp_image_path)
    os.remove(label_path)

    return fig, segment, "Rail detection completed. You can now upload a video for object detection."
|
|
def create_sample_video(output_path, duration=10, fps=30, width=640, height=480):
    """Write a short random-noise video, used as a fallback when no video is uploaded."""
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    for _ in range(duration * fps):
        frame = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
        out.write(frame)

    out.release()
    return output_path
|
|
def process_video(video_path, rail_segment, frame_skip=15):
    """Run object detection on every `frame_skip`-th frame, annotate detections relative to the rail zone, and return the output video path plus a summary."""
    if not os.path.exists(video_path):
        return None, f"Error: Video file not found at {video_path}"

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, "Error: Could not open video file."

    # Defensive: the skip value is used in modulo arithmetic, so force it to a positive integer.
    frame_skip = max(1, int(frame_skip))

    fps = int(cap.get(cv2.CAP_PROP_FPS))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    output_dir = 'output_videos'
    os.makedirs(output_dir, exist_ok=True)

    timestamp = int(time.time())
    output_filename = f'processed_video_{timestamp}.mp4'
    output_path = os.path.join(output_dir, output_filename)

    # Only every frame_skip-th frame is written, so lower the output frame rate
    # accordingly (clamped to at least 1 fps to keep VideoWriter valid).
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, max(1, fps // frame_skip), (width, height))

    frame_count = 0
    processed_count = 0
    threshold = 10  # pixel distance from the rail polygon below which a box counts as "obstructed"

    obstructed_frames = 0
    all_detections = []
|
|
    # First pass: run the detector on every frame_skip-th frame and collect the raw YOLO boxes.
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        frame_count += 1
        if frame_count % frame_skip != 0:
            continue

        processed_count += 1

        temp_frame_path = f"temp_frame_{processed_count:04d}.jpg"
        cv2.imwrite(temp_frame_path, frame)

        # Run the detection model as an external script; labels are written to the run directory.
        os.system(f"python detect.py --source {temp_frame_path} --img 640 --device cpu --weights models/detect/yolov9-s-converted.pt --name yolov9_c_640_detect --exist-ok --save-txt")

        label_path = f'runs/detect/yolov9_c_640_detect/labels/temp_frame_{processed_count:04d}.txt'
        # No label file is produced when nothing is detected in the frame.
        yolo_boxes = read_yolo_boxes(label_path) if os.path.exists(label_path) else []
        all_detections.append(yolo_boxes)

        os.remove(temp_frame_path)
        if os.path.exists(label_path):
            os.remove(label_path)

        print(f"Processed frame {frame_count}/{total_frames} (Frame {processed_count})")

    # Rewind and make a second pass to draw the rail zone and the annotated boxes.
    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
    frame_count = 0
    processed_count = 0
|
|
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        frame_count += 1
        if frame_count % frame_skip != 0:
            continue

        processed_count += 1

        # Draw the rail zone polygon.
        pixel_segment = convert_segment_to_pixel(rail_segment, width, height)
        cv2.polylines(frame, [np.array(pixel_segment)], True, (0, 255, 0), 2)

        # Draw each detection, colour-coded by its relationship to the rail zone.
        frame_obstructed = False
        for box in all_detections[processed_count - 1]:
            class_name, x, y, w, h = box
            relationship = box_segment_relationship((0, x, y, w, h), rail_segment, width, height, threshold)
            x1, y1, x2, y2 = yolo_to_pixel_coords(x, y, w, h, width, height)

            if relationship == "intersecting":
                color = (0, 0, 255)  # red
                frame_obstructed = True
            elif relationship == "obstructed":
                color = (0, 255, 255)  # yellow
                frame_obstructed = True
            else:
                color = (255, 0, 0)  # blue

            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            cv2.putText(frame, f"{class_name} ({relationship})", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)

        if frame_obstructed:
            obstructed_frames += 1

        out.write(frame)

        print(f"Processed frame {frame_count}/{total_frames} (Frame {processed_count})")

    cap.release()
    out.release()

    if processed_count == 0:
        return None, "Error: No frames were processed."

    obstruction_percentage = (obstructed_frames / processed_count) * 100
    summary = f"Video processing completed. Processed {processed_count} out of {total_frames} frames.\n\n"
    summary += "Obstruction Summary:\n"
    summary += f"Total processed frames: {processed_count}\n"
    summary += f"Frames with obstructions: {obstructed_frames}\n"
    summary += f"Percentage of frames with obstructions: {obstruction_percentage:.2f}%\n"
    summary += f"Output video saved as: {output_path}"

    return output_path, summary
|
|
class TwoStepDetection:
    """Holds the rail segment found in step 1 so that step 2 (video processing) can reuse it."""

    def __init__(self):
        self.rail_segment = None

    def rail_detection(self, rail_input):
        if rail_input is None:
            return None, "Please upload an image for rail detection."

        rail_fig, self.rail_segment, message = detect_rail(rail_input)
        return rail_fig, message

    def object_detection(self, video_input, frame_skip=15):
        if self.rail_segment is None:
            return None, "Please complete rail detection first."

        # Fall back to a generated sample clip when no video was uploaded.
        if video_input is None:
            sample_video_path = "sample_video.mp4"
            create_sample_video(sample_video_path)
            video_input = sample_video_path

        video_output, processing_message = process_video(video_input, self.rail_segment, frame_skip)
        return video_output, processing_message
|
|
detector = TwoStepDetection()

with gr.Blocks(title="Two-Step Train Obstruction Detection") as iface:
    gr.Markdown("# Two-Step Train Obstruction Detection")
    gr.Markdown("Step 1: Upload an image to detect the rail. Step 2: Upload a video to detect obstructions.")

    with gr.Tab("Step 1: Rail Detection"):
        rail_input = gr.Image(type="numpy", label="Upload image for rail detection")
        rail_output = gr.Plot(label="Rail Detection Result")
        rail_message = gr.Textbox(label="Rail Detection Message")
        rail_button = gr.Button("Detect Rail")

    with gr.Tab("Step 2: Object Detection"):
        video_input = gr.Video(label="Upload video for object detection")
        frame_skip = gr.Slider(minimum=1, maximum=100, step=1, value=15, label="Frame Skip Rate")
        video_output = gr.Video(label="Object Detection Result")
        object_message = gr.Textbox(label="Object Detection Results")
        object_button = gr.Button("Detect Objects")

    rail_button.click(detector.rail_detection, inputs=rail_input, outputs=[rail_output, rail_message])
    object_button.click(detector.object_detection, inputs=[video_input, frame_skip], outputs=[video_output, object_message])


if __name__ == "__main__":
    iface.launch()