import spaces
import os
import cv2
import numpy as np
import tempfile
import gradio as gr
from ultralytics import YOLO  # YOLO pose model with built-in tracking and keypoint estimation
import torch
css = """
/* This targets the container of the ImageEditor by its element ID */
#my_image_editor {
    height: 1000px !important; /* Adjust to the desired editor height */
}
/* You might also need to target inner elements if the component uses nested divs */
#my_image_editor .image-editor-canvas {
height: 1000px !important;
}
"""
# Get the directory where the current script is located.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
EXAMPLE_VIDEO = os.path.join(BASE_DIR, "examples", "faint.mp4")
# ----------------------------
# Helper: Extract red polygon from editor drawing (alert zone)
def extract_polygon_from_editor(editor_image, epsilon_ratio=0.01):
if editor_image is None:
return None, "❌ No alert zone drawing provided."
composite = editor_image.get("composite")
original = editor_image.get("background")
if composite is None or original is None:
return None, "⚠️ Please load the first frame and add a drawing layer with the zone."
composite_np = np.array(composite)
# Detect red strokes (assume vivid red)
r_channel = composite_np[:, :, 0]
g_channel = composite_np[:, :, 1]
b_channel = composite_np[:, :, 2]
red_mask = (r_channel > 150) & (g_channel < 100) & (b_channel < 100)
binary_mask = red_mask.astype(np.uint8) * 255
# Find contours and approximate the largest to a polygon
contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
if not contours:
return None, "⚠️ No visible drawing found. Please draw your alert zone with red strokes."
largest_contour = max(contours, key=cv2.contourArea)
epsilon = epsilon_ratio * cv2.arcLength(largest_contour, True)
polygon = cv2.approxPolyDP(largest_contour, epsilon, True)
if polygon is None or len(polygon) < 3:
return None, "⚠️ Polygon extraction failed. Try drawing a clearer alert zone."
# Reshape polygon to a list of (x, y) coordinates.
polygon_coords = polygon.reshape(-1, 2).tolist()
return polygon_coords, f"✅ Alert zone polygon with {len(polygon_coords)} points extracted."
# ----------------------------
# Helper: Draw preview image with the approximated alert zone drawn on the background.
def preview_zone_on_frame(editor_image, epsilon_ratio=0.01):
    if editor_image is None:
        return None, "❌ No alert zone drawing provided."
    background = editor_image.get("background")
    if background is None:
        return None, "⚠️ Background frame is missing from the editor image."
# Convert the background to a NumPy array copy.
preview = np.array(background).copy()
polygon, msg = extract_polygon_from_editor(editor_image, epsilon_ratio)
if polygon is None:
return None, msg
pts = np.array(polygon, np.int32).reshape((-1, 1, 2))
    # Draw the alert zone in red (the preview image is RGB, so red is (255, 0, 0)).
    cv2.polylines(preview, [pts], isClosed=True, color=(255, 0, 0), thickness=5)
return preview, f"Preview generated. {msg}"
# ----------------------------
# Helper: Compute Euclidean distance
def compute_distance(p1, p2):
return np.sqrt((p1[0]-p2[0])**2 + (p1[1]-p2[1])**2)
# Helper: Bottom-center of a bounding box.
def bottom_center(box):
x1, y1, x2, y2 = box # [x1, y1, x2, y2]
return ((x1 + x2) / 2, y2)
# Helper: Draw multiline text on frame.
def draw_multiline_text(frame, text_lines, org, font=cv2.FONT_HERSHEY_SIMPLEX,
font_scale=0.4, text_color=(255,255,255), bg_color=(50,50,50),
thickness=1, line_spacing=2):
x, y = org
for line in text_lines:
(text_w, text_h), baseline = cv2.getTextSize(line, font, font_scale, thickness)
cv2.rectangle(frame, (x, y - text_h - baseline), (x + text_w, y + baseline), bg_color, -1)
cv2.putText(frame, line, (x, y), font, font_scale, text_color, thickness, cv2.LINE_AA)
y += text_h + baseline + line_spacing
# ----------------------------
# Helper: Determine if a person is lying based on integrated keypoints.
def is_lying_from_keypoints(flat_keypoints, box_height):
"""
Expects flat_keypoints as a list or array that can be reshaped into (num_keypoints, 3).
For example, if there are 17 keypoints, the length should be 51.
Uses keypoints 5 (left shoulder), 6 (right shoulder), 11 (left hip), 12 (right hip).
"""
try:
kp = np.array(flat_keypoints).reshape(-1, 3)
left_shoulder_y = kp[5][1]
right_shoulder_y = kp[6][1]
left_hip_y = kp[11][1]
right_hip_y = kp[12][1]
shoulder_y = (left_shoulder_y + right_shoulder_y) / 2.0
hip_y = (left_hip_y + right_hip_y) / 2.0
vertical_diff = abs(hip_y - shoulder_y)
if vertical_diff < (box_height * 0.25):
return True
except Exception as e:
print("Keypoint processing error:", e)
return False
# ----------------------------
# Main function: Process video with faint detection only within alert zone
@spaces.GPU
@torch.no_grad()
def process_video_with_zone(video_file, threshold_secs, velocity_threshold, editor_image, epsilon_ratio):
# Extract the alert zone polygon from the editor image.
alert_zone, zone_msg = extract_polygon_from_editor(editor_image, epsilon_ratio)
if alert_zone is None:
return zone_msg, None
cap = cv2.VideoCapture(video_file if isinstance(video_file, str) else video_file.name)
if not cap.isOpened():
return "Error opening video file.", None
fps = cap.get(cv2.CAP_PROP_FPS)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
out_path = os.path.join(tempfile.gettempdir(), "output_alert.mp4")
out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
# ----------------------------
    # Initialize the YOLO pose model (yolo11s-pose).
    # Run in tracking mode, it provides bounding boxes, keypoints, and per-person tracking IDs.
if torch.cuda.is_available():
device = "cuda"
    elif hasattr(torch, "xpu") and torch.xpu.is_available():  # Intel GPU backend; guarded so PyTorch builds without torch.xpu don't raise
device = "xpu"
else:
device = "cpu"
yolov8spose_model = YOLO('yolo11s-pose.pt', task='pose')
yolov8spose_model.to(device)
yolov8spose_model.eval()
# Dictionaries to track static (motionless) timings based on integrated track IDs.
lying_start_times = {} # For marking when a person first appears static.
velocity_static_info = {} # For velocity-based detection: stores (last bottom-center, frame index).
frame_index = 0
threshold_frames = threshold_secs * fps # Convert threshold seconds to frames
while True:
ret, frame = cap.read()
if not ret:
break
frame_index += 1
# Draw the alert zone on the frame in red.
pts = np.array(alert_zone, np.int32).reshape((-1, 1, 2))
cv2.polylines(frame, [pts], isClosed=True, color=(0, 0, 255), thickness=2)
        results = yolov8spose_model.track(frame, persist=True, verbose=False)[0]  # tracking mode so boxes.id carries per-person IDs
boxes = results.boxes
        kpts = results.keypoints.data if results.keypoints is not None else []  # guard for frames with no detections
for i in range(len(boxes)):
box = boxes[i].xyxy[0].cpu().numpy()
x1, y1, x2, y2 = box.astype(int)
conf = boxes[i].conf[0].item()
cls = int(boxes[i].cls[0].item())
track_id = int(boxes[i].id[0].item()) if boxes[i].id is not None else -1
if cls != 0 or conf < 0.5:
continue
flat_keypoints = kpts[i].cpu().numpy().flatten().tolist()
kp = np.array(flat_keypoints).reshape(-1, 3)
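            # Draw the skeleton: connect COCO keypoint pairs (arms, legs, torso) whenever both endpoints are confident enough.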
for pair in [
(5, 6), (5, 7), (7, 9), (6, 8), (8, 10),
(11, 12), (11, 13), (13, 15), (12, 14), (14, 16),
(5, 11), (6, 12)
]:
i1, j1 = pair
if kp[i1][2] > 0.3 and kp[j1][2] > 0.3:
pt1 = (int(kp[i1][0]), int(kp[i1][1]))
pt2 = (int(kp[j1][0]), int(kp[j1][1]))
cv2.line(frame, pt1, pt2, (0, 255, 255), 2)
if len(kp) > 12:
pt = ((kp[11][0] + kp[12][0]) / 2, (kp[11][1] + kp[12][1]) / 2)
else:
continue
pt = (float(pt[0]), float(pt[1]))
in_alert_zone = cv2.pointPolygonTest(np.array(alert_zone, np.int32), pt, False) >= 0
cv2.circle(frame, (int(pt[0]), int(pt[1])), 5, (0, 0, 255), -1)
if not in_alert_zone:
status = "Outside Zone"
color = (200, 200, 200)
cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
draw_multiline_text(frame, [f"ID {track_id}: {status}"], (x1, max(y1-10, 0)))
continue
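            # Heuristic posture check: a wide, low bounding box combined with a small shoulder-to-hip vertical gap suggests a lying pose.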
aspect_ratio = (x2 - x1) / float(y2 - y1) if (y2 - y1) > 0 else 0
base_lying = aspect_ratio > 1.5 and y2 > height * 0.5
integrated_lying = is_lying_from_keypoints(flat_keypoints, y2 - y1)
pose_static = base_lying and integrated_lying
current_bottom = bottom_center((x1, y1, x2, y2))
cv2.circle(frame, (int(current_bottom[0]), int(current_bottom[1])), 3, (255, 0, 0), -1) # bottom center marker
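            # Smooth the bottom-center with an exponential moving average; a per-frame displacement below the threshold marks the track as velocity-static.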
alpha = 0.8
if track_id not in velocity_static_info:
velocity_static_info[track_id] = (current_bottom, frame_index)
smoothed = current_bottom
velocity_val = 0.0
velocity_static = False
else:
prev_pt, _ = velocity_static_info[track_id]
smoothed = alpha * np.array(prev_pt) + (1 - alpha) * np.array(current_bottom)
velocity_static_info[track_id] = (smoothed.tolist(), frame_index)
distance = compute_distance(smoothed, prev_pt)
velocity_val = distance * fps
velocity_static = distance < velocity_threshold
is_static = pose_static or velocity_static
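            # Track how long each ID has stayed static; the timer resets as soon as the person moves again.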
if is_static:
if track_id not in lying_start_times:
lying_start_times[track_id] = frame_index
duration_frames = frame_index - lying_start_times[track_id]
else:
lying_start_times.pop(track_id, None)
duration_frames = 0
if duration_frames >= threshold_frames:
status = f"FAINTED ({duration_frames/fps:.1f}s)"
color = (0, 0, 255)
elif is_static:
status = f"Static ({duration_frames/fps:.1f}s)"
color = (0, 255, 255)
else:
status = "Upright"
color = (0, 255, 0)
cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
draw_multiline_text(frame, [f"ID {track_id}: {status}"], (x1, max(y1-10, 0)))
vel_text = f"Vel: {velocity_val:.1f} px/s"
text_offset = 15
(vt_w, vt_h), vt_baseline = cv2.getTextSize(vel_text, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)
vel_org = (int(pt[0] - vt_w / 2), int(pt[1] + text_offset + vt_h))
cv2.rectangle(frame, (vel_org[0], vel_org[1] - vt_h - vt_baseline),
(vel_org[0] + vt_w, vel_org[1] + vt_baseline), (50,50,50), -1)
cv2.putText(frame, vel_text, vel_org, cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255,255,255), 1, cv2.LINE_AA)
out.write(frame)
cap.release()
out.release()
final_msg = f"{zone_msg}\nProcessed video saved to: {out_path}"
return final_msg, out_path
# ----------------------------
# Gradio Interface Construction
with gr.Blocks(css=css) as demo:
gr.HTML("<style>body { margin: 0; padding: 0; }</style>")
gr.Markdown("## 🚨 Faint Detection in a User-Defined Alert Zone")
gr.Markdown(
"""
**Instructions:**
1. Upload a video.
2. Click **Load First Frame to Editor** to extract a frame.
3. Add a drawing layer and draw your alert zone using red strokes.
4. Click **Preview Alert Zone** to verify the polygon approximation.
5. Adjust the polygon approximation if needed.
6. Process the video; detection will only occur within the alert zone.
"""
)
with gr.Tab("Load Video & Define Alert Zone"):
video_input = gr.Video(label="Upload Video", format="mp4")
with gr.Row():
gr.Examples(
examples=[
[EXAMPLE_VIDEO]
],
inputs=[video_input],
label="Try Example Video"
)
load_frame_btn = gr.Button("Load First Frame to Editor")
# Assign an elem_id for targeting via CSS.
frame_editor = gr.ImageEditor(
label="Draw Alert Zone on this frame (use red brush)",
type="numpy",
elem_id="my_image_editor"
)
preview_button = gr.Button("Preview Alert Zone")
polygon_info = gr.Textbox(label="Alert Zone Polygon Info", lines=3)
preview_image = gr.Image(label="Alert Zone Preview (Polygon Overlay)", type="numpy")
epsilon_slider = gr.Slider(
label="Polygon Approximation (ε)", minimum=0.001, maximum=0.05, value=0.01, step=0.001
)
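        # Larger ε values simplify the traced outline into a polygon with fewer vertices (via cv2.approxPolyDP).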
with gr.Tab("Process Video"):
motion_threshold_slider = gr.Slider(1, 600, value=3, step=1, label="Motionless Duration Threshold (seconds)")
        velocity_threshold_slider = gr.Slider(0.5, 20.0, value=3.0, step=0.5, label="Velocity Threshold (pixels per frame)")
output_text = gr.Textbox(label="Processing Info", lines=6)
video_output = gr.Video(label="Processed Video", format="mp4")
# Function to load and display the first frame from the video.
def load_frame(video_file):
cap = cv2.VideoCapture(video_file if isinstance(video_file, str) else video_file.name)
if not cap.isOpened():
return None, "❌ Failed to open video."
ret, frame = cap.read()
cap.release()
if not ret or frame is None or frame.size == 0:
return None, "❌ Failed to extract frame from video."
return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), "Frame loaded successfully. Now draw your alert zone."
load_frame_btn.click(fn=load_frame, inputs=video_input, outputs=[frame_editor, polygon_info])
# Button to preview alert zone polygon as both coordinates and a preview image.
def preview_alert_zone(editor_image, epsilon):
poly, msg = extract_polygon_from_editor(editor_image, epsilon)
preview, preview_msg = preview_zone_on_frame(editor_image, epsilon)
if preview is None:
return msg, None
return f"Extracted Polygon Coordinates:\n{poly}\n{msg}", preview
preview_button.click(fn=preview_alert_zone, inputs=[frame_editor, epsilon_slider], outputs=[polygon_info, preview_image])
# Process the video with faint detection within the alert zone.
process_btn = gr.Button("Process Video in Alert Zone")
process_btn.click(
fn=process_video_with_zone,
inputs=[video_input, motion_threshold_slider, velocity_threshold_slider, frame_editor, epsilon_slider],
outputs=[output_text, video_output]
)
demo.launch()