# Hugging Face Spaces page header (scraped): status "Running on Zero" (ZeroGPU hardware).
import spaces | |
import os | |
import cv2 | |
import numpy as np | |
import tempfile | |
import gradio as gr | |
from ultralytics import YOLO # Now used for Yolov8spose with integrated tracker & pose estimation | |
import torch | |
# Custom stylesheet passed to gr.Blocks: enlarges the ImageEditor whose elem_id is
# "my_image_editor" (and its inner canvas) so the alert zone can be drawn comfortably.
css = """
/* This targets the container of the ImageEditor by its element ID */
#my_image_editor {
height: 1000px !important; /* Change 600px to your desired height */
}
/* You might also need to target inner elements if the component uses nested divs */
#my_image_editor .image-editor-canvas {
height: 1000px !important;
}
"""
# Paths are anchored at this script's own folder rather than the process's working
# directory, so the bundled example clip resolves wherever the app is launched from.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]
EXAMPLE_VIDEO = os.path.join(BASE_DIR, "examples", "faint.mp4")
# ----------------------------
# Helper: Extract red polygon from editor drawing (alert zone)
def _red_stroke_mask(composite, background=None):
    """Return a uint8 mask (255 = stroke pixel) of vivid-red pixels painted by the user.

    A pixel qualifies when it is vivid red in the composite (R > 150, G < 100,
    B < 100) and -- when a background with the same height/width is supplied --
    differs from the background at that location. The second condition stops red
    objects that are already present in the video frame from being mistaken for
    the drawn zone. Images without colour channels yield an all-zero mask.
    """
    composite_np = np.asarray(composite)
    if composite_np.ndim != 3 or composite_np.shape[2] < 3:
        # Grayscale (or otherwise channel-less) input cannot contain red strokes.
        return np.zeros(composite_np.shape[:2], dtype=np.uint8)

    # Only the first three channels are inspected, so RGBA input works as well.
    red = (
        (composite_np[:, :, 0] > 150)
        & (composite_np[:, :, 1] < 100)
        & (composite_np[:, :, 2] < 100)
    )

    if background is not None:
        bg = np.asarray(background)
        if bg.ndim == 2:
            bg = bg[:, :, None]  # broadcast a grayscale background over the colour channels
        if bg.ndim == 3 and bg.shape[:2] == composite_np.shape[:2]:
            changed = np.any(composite_np[:, :, :3] != bg[:, :, :3], axis=2)
            red &= changed
        # If the shapes disagree the background cannot be aligned; fall back to
        # plain colour thresholding.

    return red.astype(np.uint8) * 255


def extract_polygon_from_editor(editor_image, epsilon_ratio=0.01):
    """Extract the alert-zone polygon the user drew in red on the editor frame.

    Args:
        editor_image: Gradio ImageEditor value -- a dict whose "composite" entry is
            the frame with drawings applied and whose "background" entry is the
            frame as loaded. Either may be None when nothing has been loaded yet.
        epsilon_ratio: Tolerance for ``cv2.approxPolyDP``, expressed as a fraction
            of the largest contour's perimeter; larger values give fewer vertices.

    Returns:
        ``(points, message)`` where ``points`` is a list of ``[x, y]`` pixel
        coordinates (at least 3) in the editor image's coordinate system, or
        ``(None, message)`` explaining why no polygon could be extracted.
    """
    if editor_image is None:
        return None, "❌ No alert zone drawing provided."
    composite = editor_image.get("composite")
    original = editor_image.get("background")
    if composite is None or original is None:
        return None, "⚠️ Please load the first frame and add a drawing layer with the zone."

    # Only pixels the user actually painted red count; red already in the frame is ignored.
    binary_mask = _red_stroke_mask(composite, original)
    if binary_mask.ndim != 2 or not binary_mask.any():
        return None, "⚠️ No visible drawing found. Please draw your alert zone with red strokes."

    # Outer contours only: an outlined zone yields a single contour along its outside edge.
    contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None, "⚠️ No visible drawing found. Please draw your alert zone with red strokes."

    largest_contour = max(contours, key=cv2.contourArea)
    epsilon = epsilon_ratio * cv2.arcLength(largest_contour, True)
    polygon = cv2.approxPolyDP(largest_contour, epsilon, True)
    if polygon is None or len(polygon) < 3:
        return None, "⚠️ Polygon extraction failed. Try drawing a clearer alert zone."

    # approxPolyDP returns shape (N, 1, 2); flatten to a plain list of (x, y) pairs.
    polygon_coords = polygon.reshape(-1, 2).tolist()
    return polygon_coords, f"✅ Alert zone polygon with {len(polygon_coords)} points extracted."
# ----------------------------
# Helper: Draw preview image with the approximated alert zone drawn on the background.
def preview_zone_on_frame(editor_image, epsilon_ratio=0.01):
    """Render the extracted alert-zone polygon on top of the editor's background frame.

    Args:
        editor_image: Gradio ImageEditor value (dict with "background" and
            "composite" entries), or None when nothing has been loaded.
        epsilon_ratio: Polygon simplification tolerance forwarded to
            ``extract_polygon_from_editor``.

    Returns:
        ``(preview, message)`` where ``preview`` is a fresh 3-channel copy of the
        background with the zone outlined, or ``(None, message)`` on failure.
        The editor's own background array is never modified.
    """
    if editor_image is None:
        return None, "❌ No alert zone drawing provided."
    background = editor_image.get("background")
    if background is None:
        return None, "⚠️ Background frame is missing from the editor image."

    polygon, msg = extract_polygon_from_editor(editor_image, epsilon_ratio)
    if polygon is None:
        return None, msg

    frame = np.asarray(background)
    if frame.ndim == 2:
        frame = np.stack([frame] * 3, axis=-1)
    else:
        # Drop any alpha channel: a 3-value colour drawn on RGBA would get alpha 0
        # and the outline would show up transparent.
        frame = frame[:, :, :3]
    # Always copy into a C-contiguous buffer (OpenCV drawing needs one) so the
    # caller's background is left untouched.
    preview = np.array(frame, order="C")

    pts = np.array(polygon, np.int32).reshape((-1, 1, 2))
    # The editor frame is RGB (the UI's load_frame converts from OpenCV's BGR), so pure
    # red is (255, 0, 0) here; OpenCV's BGR red (0, 0, 255) would render blue.
    cv2.polylines(preview, [pts], isClosed=True, color=(255, 0, 0), thickness=5)
    return preview, f"Preview generated. {msg}"
# ----------------------------
# Helper: Compute Euclidean distance
def compute_distance(p1, p2):
    """Return the straight-line (Euclidean) distance between 2-D points *p1* and *p2*.

    Both points may be any indexable pair (tuple, list or NumPy array); the
    result is produced by ``np.sqrt`` and is therefore a NumPy float.
    """
    dx = p1[0] - p2[0]
    dy = p1[1] - p2[1]
    return np.sqrt(dx ** 2 + dy ** 2)
# Helper: Bottom-center of a bounding box.
def bottom_center(box):
    """Return ``(x, y)`` of the midpoint of the bottom edge of an ``(x1, y1, x2, y2)`` box."""
    left, _top, right, bottom = box
    return ((left + right) / 2, bottom)
# Helper: Draw multiline text on frame.
def draw_multiline_text(frame, text_lines, org, font=cv2.FONT_HERSHEY_SIMPLEX,
                        font_scale=0.4, text_color=(255,255,255), bg_color=(50,50,50),
                        thickness=1, line_spacing=2):
    """Draw each string of *text_lines* onto *frame* (in place), stacked downwards.

    *org* is the ``(x, y)`` baseline origin of the first line. Every line is
    painted on a filled ``bg_color`` box sized from ``cv2.getTextSize``, and the
    next line's baseline moves down by that line's height plus ``line_spacing``.
    """
    left, baseline_y = org
    for text in text_lines:
        size, descent = cv2.getTextSize(text, font, font_scale, thickness)
        width, glyph_h = size
        box_top_left = (left, baseline_y - glyph_h - descent)
        box_bottom_right = (left + width, baseline_y + descent)
        cv2.rectangle(frame, box_top_left, box_bottom_right, bg_color, -1)
        cv2.putText(frame, text, (left, baseline_y), font, font_scale, text_color, thickness, cv2.LINE_AA)
        baseline_y += glyph_h + descent + line_spacing
# ----------------------------
# Helper: Determine if a person is lying based on integrated keypoints.
def is_lying_from_keypoints(flat_keypoints, box_height, conf_threshold=0.5):
    """Heuristically decide whether a detected pose looks horizontal (lying down).

    Args:
        flat_keypoints: Flat sequence reshapeable to (num_keypoints, 3) rows of
            ``(x, y, confidence)`` -- e.g. 51 values for 17 keypoints. Keypoints
            5/6 (left/right shoulder) and 11/12 (left/right hip) are used.
        box_height: Height in pixels of the person's bounding box.
        conf_threshold: Keypoints whose confidence is not above this value are
            ignored, because their coordinates are not reliable (undetected points
            may be reported at (0, 0)).

    Returns:
        True when the vertical gap between the mean visible shoulder height and
        the mean visible hip height is below a quarter of ``box_height``; False
        otherwise, including when no shoulder or no hip is visible or the
        keypoint data is too short or malformed.
    """
    try:
        kp = np.asarray(flat_keypoints, dtype=float).reshape(-1, 3)
    except (TypeError, ValueError) as e:
        print("Keypoint processing error:", e)
        return False
    if len(kp) <= 12:
        # Not enough keypoints to contain both shoulders and both hips.
        return False

    def mean_visible_y(*indices):
        """Mean y of the listed keypoints that pass the confidence cutoff, or None."""
        ys = [kp[i, 1] for i in indices if kp[i, 2] > conf_threshold]
        return sum(ys) / len(ys) if ys else None

    shoulder_y = mean_visible_y(5, 6)
    hip_y = mean_visible_y(11, 12)
    if shoulder_y is None or hip_y is None:
        # Without a visible shoulder and hip the torso orientation is unknown;
        # previously invisible (0, 0) points produced a false "lying" verdict.
        return False
    return bool(abs(hip_y - shoulder_y) < box_height * 0.25)
# ----------------------------
# Main function: Process video with faint detection only within alert zone

# COCO-17 keypoint index pairs drawn as the skeleton overlay.
_SKELETON_PAIRS = (
    (5, 6), (5, 7), (7, 9), (6, 8), (8, 10),
    (11, 12), (11, 13), (13, 15), (12, 14), (14, 16),
    (5, 11), (6, 12),
)
# Minimum keypoint confidence for drawing a limb or trusting the hip keypoints.
_KP_DRAW_CONF = 0.3
# Frame rate assumed when the container reports none (0 or NaN).
_FALLBACK_FPS = 30.0


def _select_device():
    """Return "cuda" if available, else "xpu" if this PyTorch build has one, else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    # torch.xpu does not exist in older PyTorch builds, so look it up defensively.
    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():
        return "xpu"
    return "cpu"


def _draw_skeleton(frame, kp):
    """Draw (in place, BGR yellow) every limb whose two keypoints are confidently detected."""
    for a, b in _SKELETON_PAIRS:
        if a >= len(kp) or b >= len(kp):
            continue
        if kp[a][2] > _KP_DRAW_CONF and kp[b][2] > _KP_DRAW_CONF:
            cv2.line(frame, (int(kp[a][0]), int(kp[a][1])),
                     (int(kp[b][0]), int(kp[b][1])), (0, 255, 255), 2)


def _zone_anchor(kp, fallback):
    """Return the (x, y) point tested against the alert zone.

    This is the mid-point of the hip keypoints (11, 12) when both are confident;
    otherwise their coordinates are unreliable and *fallback* is used instead.
    """
    if kp[11][2] > _KP_DRAW_CONF and kp[12][2] > _KP_DRAW_CONF:
        return (float(kp[11][0] + kp[12][0]) / 2.0, float(kp[11][1] + kp[12][1]) / 2.0)
    return (float(fallback[0]), float(fallback[1]))


def _draw_velocity_label(frame, anchor, velocity_val, text_offset=15):
    """Draw a "Vel: … px/s" label horizontally centred just below *anchor*."""
    vel_text = f"Vel: {velocity_val:.1f} px/s"
    (vt_w, vt_h), vt_baseline = cv2.getTextSize(vel_text, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)
    vel_org = (int(anchor[0] - vt_w / 2), int(anchor[1] + text_offset + vt_h))
    cv2.rectangle(frame, (vel_org[0], vel_org[1] - vt_h - vt_baseline),
                  (vel_org[0] + vt_w, vel_org[1] + vt_baseline), (50, 50, 50), -1)
    cv2.putText(frame, vel_text, vel_org, cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1, cv2.LINE_AA)


def process_video_with_zone(video_file, threshold_secs, velocity_threshold, editor_image, epsilon_ratio):
    """Annotate *video_file* with faint detection restricted to the user-drawn alert zone.

    Args:
        video_file: Path string, or an object exposing the path as ``.name``.
        threshold_secs: Seconds a person must remain static inside the zone
            before being labelled FAINTED.
        velocity_threshold: Per-frame movement (pixels) of the smoothed box
            bottom-center below which a person counts as static.
        editor_image: Gradio ImageEditor value containing the red zone drawing.
        epsilon_ratio: Polygon simplification tolerance forwarded to
            ``extract_polygon_from_editor``.

    Returns:
        ``(message, output_path)`` on success, otherwise ``(error_message, None)``.
    """
    if video_file is None:
        return "❌ Please upload a video first.", None

    alert_zone, zone_msg = extract_polygon_from_editor(editor_image, epsilon_ratio)
    if alert_zone is None:
        return zone_msg, None

    cap = cv2.VideoCapture(video_file if isinstance(video_file, str) else video_file.name)
    if not cap.isOpened():
        cap.release()
        return "Error opening video file.", None

    fps = cap.get(cv2.CAP_PROP_FPS)
    if not (fps and fps > 0):
        # Some containers report 0/NaN; that would break the writer and the
        # seconds<->frames conversions below.
        fps = _FALLBACK_FPS
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # A unique output file per call, so concurrent sessions cannot overwrite each other.
    fd, out_path = tempfile.mkstemp(prefix="output_alert_", suffix=".mp4")
    os.close(fd)
    out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
    if not out.isOpened():
        cap.release()
        out.release()
        return "Error creating output video file.", None

    try:
        # A fresh model per call also gives a fresh tracker state, so track IDs
        # from a previous video never leak into this one.
        model = YOLO('yolo11s-pose.pt', task='pose')
        model.to(_select_device())
        model.eval()

        zone_np = np.array(alert_zone, np.int32)   # (N, 2) for pointPolygonTest
        zone_pts = zone_np.reshape((-1, 1, 2))     # (N, 1, 2) for polylines

        lying_start_times = {}     # track_id -> frame index when it first became static
        velocity_static_info = {}  # track_id -> (smoothed bottom-center, frame index)
        threshold_frames = threshold_secs * fps
        alpha = 0.8                # EMA weight on the previous smoothed position
        frame_index = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_index += 1
            cv2.polylines(frame, [zone_pts], isClosed=True, color=(0, 0, 255), thickness=2)

            # track() (unlike a plain predict call) runs the tracker, so boxes.id is populated.
            results = model.track(frame, persist=True, verbose=False)[0]
            boxes = results.boxes
            keypoints = results.keypoints
            if boxes is None or keypoints is None:
                out.write(frame)
                continue
            kpts = keypoints.data

            for i in range(len(boxes)):
                det = boxes[i]
                conf = det.conf[0].item()
                cls = int(det.cls[0].item())
                if cls != 0 or conf < 0.5:  # people only, reasonably confident
                    continue
                x1, y1, x2, y2 = map(int, det.xyxy[0].cpu().numpy())
                # -1 when the tracker has not (yet) assigned an ID to this detection.
                track_id = int(det.id[0].item()) if det.id is not None else -1

                flat_keypoints = kpts[i].cpu().numpy().flatten().tolist()
                kp = np.array(flat_keypoints).reshape(-1, 3)
                _draw_skeleton(frame, kp)
                if len(kp) <= 12:
                    continue

                current_bottom = bottom_center((x1, y1, x2, y2))
                pt = _zone_anchor(kp, current_bottom)
                in_alert_zone = cv2.pointPolygonTest(zone_np, pt, False) >= 0
                cv2.circle(frame, (int(pt[0]), int(pt[1])), 5, (0, 0, 255), -1)  # anchor marker

                if not in_alert_zone:
                    # Reset this track's state so time spent outside the zone never
                    # counts toward a faint after the person re-enters.
                    lying_start_times.pop(track_id, None)
                    velocity_static_info.pop(track_id, None)
                    cv2.rectangle(frame, (x1, y1), (x2, y2), (200, 200, 200), 2)
                    draw_multiline_text(frame, [f"ID {track_id}: Outside Zone"], (x1, max(y1 - 10, 0)))
                    continue

                cv2.circle(frame, (int(current_bottom[0]), int(current_bottom[1])), 3, (255, 0, 0), -1)

                # Pose cue: wide box in the lower half of the frame AND a flat torso.
                aspect_ratio = (x2 - x1) / float(y2 - y1) if (y2 - y1) > 0 else 0
                base_lying = aspect_ratio > 1.5 and y2 > height * 0.5
                integrated_lying = is_lying_from_keypoints(flat_keypoints, y2 - y1)
                pose_static = base_lying and integrated_lying

                # Motion cue: exponentially smoothed bottom-center displacement per frame.
                if track_id not in velocity_static_info:
                    velocity_static_info[track_id] = (current_bottom, frame_index)
                    velocity_val = 0.0
                    velocity_static = False
                else:
                    prev_pt, _ = velocity_static_info[track_id]
                    smoothed = alpha * np.array(prev_pt) + (1 - alpha) * np.array(current_bottom)
                    velocity_static_info[track_id] = (smoothed.tolist(), frame_index)
                    distance = compute_distance(smoothed, prev_pt)
                    velocity_val = distance * fps
                    velocity_static = distance < velocity_threshold

                # NOTE(review): a person standing still (velocity_static) also
                # accumulates toward FAINTED even when upright -- confirm intended.
                is_static = pose_static or velocity_static
                if is_static:
                    start_frame = lying_start_times.setdefault(track_id, frame_index)
                    duration_frames = frame_index - start_frame
                else:
                    lying_start_times.pop(track_id, None)
                    duration_frames = 0

                if duration_frames >= threshold_frames:
                    status = f"FAINTED ({duration_frames/fps:.1f}s)"
                    color = (0, 0, 255)
                elif is_static:
                    status = f"Static ({duration_frames/fps:.1f}s)"
                    color = (0, 255, 255)
                else:
                    status = "Upright"
                    color = (0, 255, 0)

                cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
                draw_multiline_text(frame, [f"ID {track_id}: {status}"], (x1, max(y1 - 10, 0)))
                _draw_velocity_label(frame, pt, velocity_val)

            out.write(frame)
    finally:
        # Release both handles even if inference fails mid-video.
        cap.release()
        out.release()

    final_msg = f"{zone_msg}\nProcessed video saved to: {out_path}"
    return final_msg, out_path
# ----------------------------
# Gradio Interface Construction
with gr.Blocks(css=css) as demo:
    gr.HTML("<style>body { margin: 0; padding: 0; }</style>")
    gr.Markdown("## 🚨 Faint Detection in a User-Defined Alert Zone")
    gr.Markdown(
        """
**Instructions:**
1. Upload a video.
2. Click **Load First Frame to Editor** to extract a frame.
3. Add a drawing layer and draw your alert zone using red strokes.
4. Click **Preview Alert Zone** to verify the polygon approximation.
5. Adjust the polygon approximation if needed.
6. Process the video; detection will only occur within the alert zone.
"""
    )

    with gr.Tab("Load Video & Define Alert Zone"):
        video_input = gr.Video(label="Upload Video", format="mp4")
        with gr.Row():
            gr.Examples(
                examples=[
                    [EXAMPLE_VIDEO]
                ],
                inputs=[video_input],
                label="Try Example Video"
            )
        load_frame_btn = gr.Button("Load First Frame to Editor")
        # Assign an elem_id for targeting via CSS.
        frame_editor = gr.ImageEditor(
            label="Draw Alert Zone on this frame (use red brush)",
            type="numpy",
            elem_id="my_image_editor"
        )
        preview_button = gr.Button("Preview Alert Zone")
        polygon_info = gr.Textbox(label="Alert Zone Polygon Info", lines=3)
        preview_image = gr.Image(label="Alert Zone Preview (Polygon Overlay)", type="numpy")
        epsilon_slider = gr.Slider(
            label="Polygon Approximation (ε)", minimum=0.001, maximum=0.05, value=0.01, step=0.001
        )

    with gr.Tab("Process Video"):
        motion_threshold_slider = gr.Slider(1, 600, value=3, step=1, label="Motionless Duration Threshold (seconds)")
        velocity_threshold_slider = gr.Slider(0.5, 20.0, value=3.0, step=0.5, label="Velocity Threshold (pixels)")
        output_text = gr.Textbox(label="Processing Info", lines=6)
        video_output = gr.Video(label="Processed Video", format="mp4")

    # Function to load and display the first frame from the video.
    def load_frame(video_file):
        """Return (first frame as an RGB array or None, status message) for the editor."""
        if video_file is None:
            return None, "❌ Please upload a video first."
        cap = cv2.VideoCapture(video_file if isinstance(video_file, str) else video_file.name)
        try:
            if not cap.isOpened():
                return None, "❌ Failed to open video."
            ret, frame = cap.read()
        finally:
            # Release the capture on every path, including a failed open.
            cap.release()
        if not ret or frame is None or frame.size == 0:
            return None, "❌ Failed to extract frame from video."
        # OpenCV decodes frames as BGR; the editor displays RGB.
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), "Frame loaded successfully. Now draw your alert zone."

    load_frame_btn.click(fn=load_frame, inputs=video_input, outputs=[frame_editor, polygon_info])

    # Button to preview alert zone polygon as both coordinates and a preview image.
    def preview_alert_zone(editor_image, epsilon):
        """Return (polygon info text, preview image or None) for the current drawing."""
        poly, msg = extract_polygon_from_editor(editor_image, epsilon)
        if poly is None:
            # Nothing to draw (this also covers an empty editor, i.e. editor_image None).
            return msg, None
        preview, _preview_msg = preview_zone_on_frame(editor_image, epsilon)
        if preview is None:
            return msg, None
        return f"Extracted Polygon Coordinates:\n{poly}\n{msg}", preview

    preview_button.click(fn=preview_alert_zone, inputs=[frame_editor, epsilon_slider], outputs=[polygon_info, preview_image])

    # Process the video with faint detection within the alert zone.
    process_btn = gr.Button("Process Video in Alert Zone")
    process_btn.click(
        fn=process_video_with_zone,
        inputs=[video_input, motion_threshold_slider, velocity_threshold_slider, frame_editor, epsilon_slider],
        outputs=[output_text, video_output]
    )

demo.launch()