import cv2 import mediapipe as mp import numpy as np import tflite_runtime.interpreter as tflite import time from collections import deque import gradio as gr import os import shutil # สำหรับ copy ไฟล์ example # --- Configuration (ปรับให้เหมาะกับ Gradio) --- MODEL_PATH = 'fall_detection_transformer.tflite' #MODEL_PATH = 'fall_detection_transformer_60.tflite' INPUT_TIMESTEPS = 30 FALL_CONFIDENCE_THRESHOLD = 0.90 MIN_KEYPOINT_CONFIDENCE_FOR_NORMALIZATION = 0.3 mp_pose = mp.solutions.pose pose_complexity = 0 # ลด complexity เพื่อความเร็วบน Spaces, ลอง 0 หรือ 1 use_static_image_mode = False # สำหรับวิดีโอไฟล์ จะถูก override เป็น True ใน process_video FALL_EVENT_COOLDOWN = 10 # ----- 0. KEYPOINT DEFINITIONS (เหมือนเดิม) ----- KEYPOINT_NAMES_ORIGINAL = [ 'Nose', 'Left Eye Inner', 'Left Eye', 'Left Eye Outer', 'Right Eye Inner', 'Right Eye', 'Right Eye Outer', 'Left Ear', 'Right Ear', 'Mouth Left', 'Mouth Right', 'Left Shoulder', 'Right Shoulder', 'Left Elbow', 'Right Elbow', 'Left Wrist', 'Right Wrist', 'Left Pinky', 'Right Pinky', 'Left Index', 'Right Index', 'Left Thumb', 'Right Thumb', 'Left Hip', 'Right Hip', 'Left Knee', 'Right Knee', 'Left Ankle', 'Right Ankle', 'Left Heel', 'Right Heel', 'Left Foot Index', 'Right Foot Index' ] MEDIAPIPE_TO_YOUR_KEYPOINTS_MAPPING = { mp_pose.PoseLandmark.NOSE: 'Nose', mp_pose.PoseLandmark.LEFT_EYE: 'Left Eye', mp_pose.PoseLandmark.RIGHT_EYE: 'Right Eye', mp_pose.PoseLandmark.LEFT_EAR: 'Left Ear', mp_pose.PoseLandmark.RIGHT_EAR: 'Right Ear', mp_pose.PoseLandmark.LEFT_SHOULDER: 'Left Shoulder', mp_pose.PoseLandmark.RIGHT_SHOULDER: 'Right Shoulder', mp_pose.PoseLandmark.LEFT_ELBOW: 'Left Elbow', mp_pose.PoseLandmark.RIGHT_ELBOW: 'Right Elbow', mp_pose.PoseLandmark.LEFT_WRIST: 'Left Wrist', mp_pose.PoseLandmark.RIGHT_WRIST: 'Right Wrist', mp_pose.PoseLandmark.LEFT_HIP: 'Left Hip', mp_pose.PoseLandmark.RIGHT_HIP: 'Right Hip', mp_pose.PoseLandmark.LEFT_KNEE: 'Left Knee', mp_pose.PoseLandmark.RIGHT_KNEE: 'Right Knee', mp_pose.PoseLandmark.LEFT_ANKLE: 'Left Ankle', mp_pose.PoseLandmark.RIGHT_ANKLE: 'Right Ankle' } YOUR_KEYPOINT_NAMES_TRAINING = [ 'Nose', 'Left Eye', 'Right Eye', 'Left Ear', 'Right Ear', 'Left Shoulder', 'Right Shoulder', 'Left Elbow', 'Right Elbow', 'Left Wrist', 'Right Wrist', 'Left Hip', 'Right Hip', 'Left Knee', 'Right Knee', 'Left Ankle', 'Right Ankle' ] SORTED_YOUR_KEYPOINT_NAMES = sorted(YOUR_KEYPOINT_NAMES_TRAINING) KEYPOINT_DICT_TRAINING = {name: i for i, name in enumerate(SORTED_YOUR_KEYPOINT_NAMES)} NUM_KEYPOINTS_TRAINING = len(KEYPOINT_DICT_TRAINING) NUM_FEATURES = NUM_KEYPOINTS_TRAINING * 3 print("--- Initializing Keypoint Definitions for Gradio App ---") print(f"NUM_FEATURES for model input: {NUM_FEATURES}") # --------------------------------------------------------------- # --- Load TFLite Model --- try: interpreter = tflite.Interpreter(model_path=MODEL_PATH) interpreter.allocate_tensors() input_details = interpreter.get_input_details() output_details = interpreter.get_output_details() print(f"TFLite Model Loaded: {MODEL_PATH}") model_expected_shape = tuple(input_details[0]['shape']) if model_expected_shape[2] != NUM_FEATURES or model_expected_shape[1] != INPUT_TIMESTEPS: print(f"FATAL ERROR: Model's expected input shape features/timesteps " f"({model_expected_shape[1]},{model_expected_shape[2]}) " f"does not match configured ({INPUT_TIMESTEPS},{NUM_FEATURES}).") exit() except Exception as e: print(f"Error loading TFLite model: {e}") exit() # --- Helper Functions (get_kpt_indices, normalize_skeleton_frame, extract_and_normalize_features - เหมือนเดิม) --- def get_kpt_indices_training_order(keypoint_name): if keypoint_name not in KEYPOINT_DICT_TRAINING: raise ValueError(f"Keypoint '{keypoint_name}' not found in KEYPOINT_DICT_TRAINING. Available: {list(KEYPOINT_DICT_TRAINING.keys())}") kp_idx = KEYPOINT_DICT_TRAINING[keypoint_name] return kp_idx * 3, kp_idx * 3 + 1, kp_idx * 3 + 2 def normalize_skeleton_frame(frame_features_sorted, min_confidence=MIN_KEYPOINT_CONFIDENCE_FOR_NORMALIZATION): normalized_frame = np.copy(frame_features_sorted) ref_kp_names = {'ls': 'Left Shoulder', 'rs': 'Right Shoulder', 'lh': 'Left Hip', 'rh': 'Right Hip'} try: ls_x_idx, ls_y_idx, ls_c_idx = get_kpt_indices_training_order(ref_kp_names['ls']) rs_x_idx, rs_y_idx, rs_c_idx = get_kpt_indices_training_order(ref_kp_names['rs']) lh_x_idx, lh_y_idx, lh_c_idx = get_kpt_indices_training_order(ref_kp_names['lh']) rh_x_idx, rh_y_idx, rh_c_idx = get_kpt_indices_training_order(ref_kp_names['rh']) except ValueError as e: print(f"Warning in normalize_skeleton_frame (get_kpt_indices): {e}") return frame_features_sorted ls_x, ls_y, ls_c = frame_features_sorted[ls_x_idx], frame_features_sorted[ls_y_idx], frame_features_sorted[ls_c_idx] rs_x, rs_y, rs_c = frame_features_sorted[rs_x_idx], frame_features_sorted[rs_y_idx], frame_features_sorted[rs_c_idx] lh_x, lh_y, lh_c = frame_features_sorted[lh_x_idx], frame_features_sorted[lh_y_idx], frame_features_sorted[lh_c_idx] rh_x, rh_y, rh_c = frame_features_sorted[rh_x_idx], frame_features_sorted[rh_y_idx], frame_features_sorted[rh_c_idx] mid_shoulder_x, mid_shoulder_y = np.nan, np.nan valid_ls, valid_rs = ls_c > min_confidence, rs_c > min_confidence if valid_ls and valid_rs: mid_shoulder_x, mid_shoulder_y = (ls_x + rs_x) / 2, (ls_y + rs_y) / 2 elif valid_ls: mid_shoulder_x, mid_shoulder_y = ls_x, ls_y elif valid_rs: mid_shoulder_x, mid_shoulder_y = rs_x, rs_y mid_hip_x, mid_hip_y = np.nan, np.nan valid_lh, valid_rh = lh_c > min_confidence, rh_c > min_confidence if valid_lh and valid_rh: mid_hip_x, mid_hip_y = (lh_x + rh_x) / 2, (lh_y + rh_y) / 2 elif valid_lh: mid_hip_x, mid_hip_y = lh_x, lh_y elif valid_rh: mid_hip_x, mid_hip_y = rh_x, rh_y if np.isnan(mid_hip_x) or np.isnan(mid_hip_y): return frame_features_sorted reference_height = np.nan if not np.isnan(mid_shoulder_y) and not np.isnan(mid_hip_y): reference_height = np.abs(mid_shoulder_y - mid_hip_y) perform_scaling = not (np.isnan(reference_height) or reference_height < 1e-5) for kp_name_sorted in SORTED_YOUR_KEYPOINT_NAMES: try: x_col, y_col, _ = get_kpt_indices_training_order(kp_name_sorted) normalized_frame[x_col] -= mid_hip_x normalized_frame[y_col] -= mid_hip_y if perform_scaling: normalized_frame[x_col] /= reference_height normalized_frame[y_col] /= reference_height except ValueError: # Should not happen if kp_name_sorted is from SORTED_YOUR_KEYPOINT_NAMES pass return normalized_frame def extract_and_normalize_features(pose_results): frame_features_sorted = np.zeros(NUM_FEATURES, dtype=np.float32) if pose_results.pose_landmarks: landmarks = pose_results.pose_landmarks.landmark for mp_landmark_enum, your_kp_name in MEDIAPIPE_TO_YOUR_KEYPOINTS_MAPPING.items(): if your_kp_name in KEYPOINT_DICT_TRAINING: try: lm = landmarks[mp_landmark_enum.value] x_idx, y_idx, c_idx = get_kpt_indices_training_order(your_kp_name) frame_features_sorted[x_idx], frame_features_sorted[y_idx], frame_features_sorted[c_idx] = lm.x, lm.y, lm.visibility except (IndexError, ValueError) as e: print(f"Warning in extract_and_normalize_features for {your_kp_name}: {e}") pass normalized_features = normalize_skeleton_frame(frame_features_sorted.copy()) return normalized_features # ------------------------------------------------------------------------------------------------------------------- # --- Function to process uploaded video for Gradio --- def process_video_for_gradio(uploaded_video_path_temp): if uploaded_video_path_temp is None: return None, "Please upload a video file." print(f"Gradio provided temp video path: {uploaded_video_path_temp}") base_name = os.path.basename(uploaded_video_path_temp) # สร้าง path ที่ unique มากขึ้นสำหรับไฟล์ที่ copy มา timestamp_str = str(int(time.time() * 1000)) # เพิ่ม timestamp เพื่อความ unique local_video_path = os.path.join(os.getcwd(), f"{timestamp_str}_{base_name}") try: print(f"Copying video from {uploaded_video_path_temp} to {local_video_path}") shutil.copy2(uploaded_video_path_temp, local_video_path) print(f"Video copied successfully to {local_video_path}") except Exception as e: error_msg = f"Error copying video file: {e}\nTemp path: {uploaded_video_path_temp}" print(error_msg); return None, error_msg local_feature_sequence = deque(maxlen=INPUT_TIMESTEPS) local_last_fall_event_time = 0 # ใช้ local_last_fall_event_time_sec เพื่อความชัดเจนว่าเป็นหน่วยวินาทีของวิดีโอ cap = cv2.VideoCapture(local_video_path) if not cap.isOpened(): error_msg = f"Error: OpenCV cannot open video file at copied path: {local_video_path}" if os.path.exists(local_video_path): print(f"File size of '{local_video_path}': {os.path.getsize(local_video_path)} bytes") else: print(f"File '{local_video_path}' does not exist after copy attempt.") if os.path.exists(local_video_path): os.remove(local_video_path) # Cleanup return None, error_msg fps = cap.get(cv2.CAP_PROP_FPS) if fps == 0 or np.isnan(fps) or fps < 1: fps = 25.0 # Default FPS, ensure it's float processed_frames_list = [] overall_status_updates = [] with mp_pose.Pose( static_image_mode=True, model_complexity=pose_complexity, smooth_landmarks=True, min_detection_confidence=0.5, min_tracking_confidence=0.5) as pose: frame_count = 0 while cap.isOpened(): success, original_bgr_frame = cap.read() # อ่าน frame มาเป็น BGR if not success: break frame_count += 1 # *** START: การแก้ไขเรื่องสีและการวาด *** # สร้างสำเนาของ BGR frame สำหรับการวาดผลลัพธ์ frame_for_display = original_bgr_frame.copy() # 1. แปลงเป็น RGB เฉพาะตอนส่งให้ MediaPipe image_rgb_for_mediapipe = cv2.cvtColor(original_bgr_frame, cv2.COLOR_BGR2RGB) image_rgb_for_mediapipe.flags.writeable = False results = pose.process(image_rgb_for_mediapipe) # image_rgb_for_mediapipe.flags.writeable = True # ไม่จำเป็นแล้ว # 2. Extract and Normalize Features current_features = extract_and_normalize_features(results) local_feature_sequence.append(current_features) # ... (ส่วนการทำนายผล prediction เหมือนเดิม) ... current_status_text_for_log = f"Frame {frame_count}: Collecting..." # สำหรับ log prediction_label = "no_fall" display_confidence_value = 0.0 if len(local_feature_sequence) == INPUT_TIMESTEPS: model_input_data = np.array(local_feature_sequence, dtype=np.float32) model_input_data = np.expand_dims(model_input_data, axis=0) try: interpreter.set_tensor(input_details[0]['index'], model_input_data) interpreter.invoke() output_data = interpreter.get_tensor(output_details[0]['index']) prediction_probability_fall = output_data[0][0] if prediction_probability_fall >= FALL_CONFIDENCE_THRESHOLD: prediction_label = "fall" display_confidence_value = prediction_probability_fall else: prediction_label = "no_fall" display_confidence_value = 1.0 - prediction_probability_fall current_status_text_for_log = f"Frame {frame_count}: {prediction_label.upper()} (Conf: {display_confidence_value:.2f})" current_video_time_sec = frame_count / fps if prediction_label == "fall": if (current_video_time_sec - local_last_fall_event_time) > FALL_EVENT_COOLDOWN: # ใช้ local_last_fall_event_time fall_message = f"Frame {frame_count} (~{current_video_time_sec:.1f}s): FALL DETECTED! (Conf: {prediction_probability_fall:.2f})" print(fall_message) overall_status_updates.append(fall_message) local_last_fall_event_time = current_video_time_sec # อัปเดตเวลา except Exception as e: print(f"Frame {frame_count}: Error during prediction: {e}") current_status_text_for_log = f"Frame {frame_count}: Prediction Error" display_confidence_value = 0.0 # อัปเดต overall_status_updates โดยใช้ current_status_text_for_log if "FALL DETECTED" not in current_status_text_for_log and \ (frame_count % int(fps*1) == 0 or (len(local_feature_sequence) == INPUT_TIMESTEPS and frame_count == INPUT_TIMESTEPS) or frame_count ==1) : if "Collecting..." not in current_status_text_for_log or frame_count == 1 : overall_status_updates.append(current_status_text_for_log) # 3. วาด Landmarks (ถ้ามี) บน frame_for_display (BGR) if results.pose_landmarks: # เพื่อให้ได้สี default ของ MediaPipe ที่ถูกต้องที่สุด, เราจะวาดบนสำเนา RGB ชั่วคราว # แล้วค่อยแปลงกลับมาเป็น BGR เพื่อใส่ใน frame_for_display temp_rgb_to_draw_landmarks = cv2.cvtColor(original_bgr_frame, cv2.COLOR_BGR2RGB).copy() mp.solutions.drawing_utils.draw_landmarks( temp_rgb_to_draw_landmarks, results.pose_landmarks, mp_pose.POSE_CONNECTIONS, landmark_drawing_spec=mp.solutions.drawing_styles.get_default_pose_landmarks_style() ) # ตอนนี้ frame_for_display ยังเป็น BGR ดั้งเดิม, เราจะเอา temp_rgb_to_draw_landmarks ที่วาดแล้ว # แปลงกลับเป็น BGR แล้วใช้เป็น frame_for_display ใหม่ frame_for_display = cv2.cvtColor(temp_rgb_to_draw_landmarks, cv2.COLOR_RGB2BGR) # ถ้าไม่มี landmarks, frame_for_display จะยังคงเป็น original_bgr_frame.copy() # 4. วาด Text บน frame_for_display (BGR) ทางขวามือ font_face = cv2.FONT_HERSHEY_DUPLEX font_scale_status = 0.6 thickness_status = 1 font_scale_alert = 1 thickness_alert = 2 padding = 30 # ระยะห่างจากขอบ text_to_show_on_frame = f"{prediction_label.upper()} (Conf: {display_confidence_value:.2f})" if "Collecting" in current_status_text_for_log or "Error" in current_status_text_for_log: # ใช้ current_status_text_for_log text_to_show_on_frame = current_status_text_for_log.split(': ')[-1] (text_w, text_h), _ = cv2.getTextSize(text_to_show_on_frame, font_face, font_scale_status, thickness_status) text_x_status = frame_for_display.shape[1] - text_w - padding text_y_status = padding + text_h status_color_bgr = (255, 255, 255) # เขียว (BGR) current_video_time_sec_for_alert_check = frame_count / fps if prediction_label == "fall" and not (current_video_time_sec_for_alert_check - local_last_fall_event_time < FALL_EVENT_COOLDOWN): status_color_bgr = (0, 165, 255) # สีส้ม (BGR) if "Error" in text_to_show_on_frame: status_color_bgr = (0,0,255) # สีแดง (BGR) cv2.putText(frame_for_display, text_to_show_on_frame, (text_x_status, text_y_status), font_face, font_scale_status, status_color_bgr, thickness_status, cv2.LINE_AA) if prediction_label == "fall" and (current_video_time_sec_for_alert_check - local_last_fall_event_time < FALL_EVENT_COOLDOWN): alert_text = "FALL DETECTED!" (alert_w, alert_h), _ = cv2.getTextSize(alert_text, font_face, font_scale_alert, thickness_alert) alert_x_pos = frame_for_display.shape[1] - alert_w - padding alert_y_pos = text_y_status + alert_h + padding // 2 cv2.putText(frame_for_display, alert_text, (alert_x_pos, alert_y_pos), font_face, font_scale_alert, (0, 0, 255), thickness_alert, cv2.LINE_AA) # สีแดง (BGR) # *** END *** processed_frames_list.append(frame_for_display) # เพิ่ม BGR frame ที่วาดแล้ว cap.release() if not processed_frames_list: if os.path.exists(local_video_path): try: os.remove(local_video_path); print(f"Cleaned up temp copied file: {local_video_path}") except Exception as e: print(f"Could not remove temp copied file {local_video_path} after no frames: {e}") return None, "No frames processed. Video might be empty or unreadable after copy." output_temp_video_path = f"processed_gradio_output_{timestamp_str}.mp4" height, width, _ = processed_frames_list[0].shape fourcc = cv2.VideoWriter_fourcc(*'mp4v') video_writer = cv2.VideoWriter(output_temp_video_path, fourcc, fps, (width, height)) for frame_out_bgr in processed_frames_list: video_writer.write(frame_out_bgr) video_writer.release() print(f"Processed video saved to: {output_temp_video_path}") summary_text = "Recent Events / Status:\n" + "\n".join(overall_status_updates[-15:]) if os.path.exists(local_video_path): try: os.remove(local_video_path); print(f"Cleaned up temp copied file: {local_video_path}") except Exception as e: print(f"Could not remove temp copied file {local_video_path}: {e}") return output_temp_video_path, summary_text # --- สร้าง Gradio Interface --- # กำหนด list ของชื่อไฟล์ตัวอย่างของคุณ example_filenames = [ "fall_example_1.mp4", # <<<< แก้ไขชื่อไฟล์ตามที่คุณใช้ "fall_example_2.mp4", # <<<< แก้ไขชื่อไฟล์ตามที่คุณใช้ "fall_example_3.mp4", # <<<< แก้ไขชื่อไฟล์ตามที่คุณใช้ "fall_example_4.mp4" # <<<< แก้ไขชื่อไฟล์ตามที่คุณใช้ ] examples_list_for_gradio = [] for filename in example_filenames: # ตรวจสอบว่าไฟล์ example มีอยู่ใน root directory ของ repo จริงๆ if os.path.exists(filename): # Gradio examples ต้องการแค่ชื่อไฟล์ (ถ้าอยู่ใน root) examples_list_for_gradio.append([filename]) # Gradio ต้องการ list ของ list print(f"Info: Example file '{filename}' found and added.") else: print(f"Warning: Example file '{filename}' not found in the repository root. It will not be added to examples.") iface = gr.Interface( fn=process_video_for_gradio, inputs=gr.Video(label="Upload Video File (.mp4)", sources=["upload"]), outputs=[ gr.Video(label="Processed Video with Detections"), gr.Textbox(label="Detection Summary (Events / Status)") ], title="AI Fall Detection from Video", description="Upload a video file (MP4 format recommended) to detect falls. " \ "Processing may take time depending on video length.", examples=examples_list_for_gradio if examples_list_for_gradio else None, # <<<< ใช้ list ใหม่นี้ allow_flagging="never", cache_examples=False ) if __name__ == "__main__": iface.launch()