Spaces:
Sleeping
Sleeping
| import os | |
| # Cấu hình thư mục tạm cho YOLO (Bắt buộc cho HuggingFace) | |
| os.environ["YOLO_CONFIG_DIR"] = "/tmp" | |
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights | |
| import albumentations as A | |
| from ultralytics import YOLO | |
| from datetime import datetime | |
| import pandas as pd | |
| from collections import deque | |
| from pathlib import Path | |
| # ============================================================ | |
| # 1. MODEL CONFIGURATION (Giữ nguyên logic của bạn) | |
| # ============================================================ | |
| DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| MODEL_PATH = "best_model_efficientnet_lstm_v2.pth" | |
| class EfficientNetLSTM(nn.Module): | |
| def __init__(self, hidden_size=256, num_layers=2, dropout=0.5): | |
| super(EfficientNetLSTM, self).__init__() | |
| weights = EfficientNet_V2_S_Weights.IMAGENET1K_V1 | |
| self.efficientnet = efficientnet_v2_s(weights=weights) | |
| num_features = self.efficientnet.classifier[1].in_features | |
| self.efficientnet.classifier = nn.Identity() | |
| self.lstm = nn.LSTM(input_size=num_features, hidden_size=hidden_size, num_layers=num_layers, | |
| batch_first=True, dropout=dropout, bidirectional=True) | |
| self.fc = nn.Sequential( | |
| nn.Linear(256*2, 256), nn.ReLU(), nn.Dropout(dropout), | |
| nn.Linear(256, 128), nn.ReLU(), nn.Dropout(dropout), | |
| nn.Linear(128, 1) | |
| ) | |
| def forward(self, x): | |
| batch_size, num_frames, c, h, w = x.shape | |
| x = x.view(batch_size * num_frames, c, h, w) | |
| features = self.efficientnet(x) | |
| features = features.view(batch_size, num_frames, -1) | |
| lstm_out, _ = self.lstm(features) | |
| output = self.fc(lstm_out[:, -1, :]) | |
| return output.squeeze() | |
| # Load Models Global | |
| print("⏳ Đang tải models...") | |
| try: | |
| model = EfficientNetLSTM().to(DEVICE) | |
| model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE)) | |
| model.eval() | |
| yolo_model = YOLO("yolov8n.pt") | |
| print("✅ Đã tải xong models!") | |
| except Exception as e: | |
| print(f"❌ Lỗi: {e}") | |
| model = None | |
| yolo_model = None | |
| # Transform | |
| transform = A.Compose([ | |
| A.Resize(height=224, width=224), | |
| A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), | |
| A.ToTensorV2(), | |
| ]) | |
| # ============================================================ | |
| # 2. SYSTEM CLASS (QUẢN LÝ TRẠNG THÁI) | |
| # ============================================================ | |
| class FallDetectionSystem: | |
| def __init__(self): | |
| # Config | |
| self.num_frames = 32 | |
| self.conf_thres = 0.5 | |
| self.output_dir = Path("fall_videos") | |
| self.output_dir.mkdir(exist_ok=True) | |
| # Realtime Buffers | |
| self.buffer = deque(maxlen=self.num_frames) # Buffer cho model | |
| self.pre_buffer = deque(maxlen=30) # Buffer lưu 30 frame trước khi ngã | |
| self.no_detect_count = 0 | |
| # Recording State | |
| self.is_recording = False | |
| self.video_writer = None | |
| self.current_video_path = None | |
| self.fall_start_time = None | |
| self.fall_frame_count = 0 | |
| # Logging & History | |
| self.log_history = [] # Cho realtime text log | |
| self.saved_videos = [] # List đường dẫn video đã lưu | |
| self.analysis_history = pd.DataFrame(columns=["Thời gian", "Video", "Kết quả", "Độ tin cậy"]) | |
| def reset_realtime_state(self): | |
| """Reset trạng thái khi bật lại camera""" | |
| self.buffer.clear() | |
| self.pre_buffer.clear() | |
| self.is_recording = False | |
| if self.video_writer: | |
| self.video_writer.release() | |
| self.video_writer = None | |
| # --- LOGIC TAB 1: VIDEO FILE ANALYSIS --- | |
| def analyze_video(self, video_path): | |
| if model is None: return "Error loading model", self.analysis_history | |
| cap = cv2.VideoCapture(video_path) | |
| frames = [] | |
| total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |
| # Logic lấy 16 frames (như code cũ) | |
| if total_frames >= 32: | |
| indices = np.linspace(0, total_frames - 1, 32, dtype=int) | |
| else: | |
| indices = np.arange(total_frames) | |
| for i in range(total_frames): | |
| ret, frame = cap.read() | |
| if not ret: break | |
| if i in indices: | |
| frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |
| frames.append(transform(image=frame_rgb)['image']) | |
| cap.release() | |
| # Pad frame nếu thiếu | |
| while len(frames) < 32: frames.append(frames[-1]) | |
| # Predict | |
| video_tensor = torch.stack(frames).unsqueeze(0).to(DEVICE) | |
| with torch.no_grad(): | |
| prob = torch.sigmoid(model(video_tensor)).item() | |
| is_fall = prob > 0.5 | |
| result_text = "⚠️ PHÁT HIỆN NGÃ" if is_fall else "✅ AN TOÀN" | |
| timestamp = datetime.now().strftime("%d/%m/%Y %H:%M") | |
| filename = os.path.basename(video_path) | |
| # Cập nhật DataFrame | |
| new_row = pd.DataFrame({ | |
| "Thời gian": [timestamp], | |
| "Video": [filename], | |
| "Kết quả": [result_text], | |
| "Độ tin cậy": [f"{prob*100:.2f}%"] | |
| }) | |
| self.analysis_history = pd.concat([new_row, self.analysis_history], ignore_index=True) | |
| return f"{result_text} ({prob*100:.2f}%)", self.analysis_history | |
| # --- LOGIC TAB 2: REALTIME PROCESSING --- | |
| def process_frame(self, image): | |
| """Hàm xử lý chính cho mỗi frame từ webcam""" | |
| if image is None: return image, "", "", [] | |
| # 1. Chuẩn bị dữ liệu | |
| frame_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # OpenCV dùng BGR | |
| current_time = datetime.now().strftime('%H:%M:%S') | |
| # Thêm vào pre-buffer (để ghi video lùi lại quá khứ) | |
| self.pre_buffer.append(frame_bgr) | |
| # 2. Detect Người (YOLO) | |
| results = yolo_model(frame_bgr, verbose=False, conf=self.conf_thres) | |
| boxes = results[0].boxes.data.cpu().numpy() | |
| person_box = None | |
| for x1, y1, x2, y2, conf, cls in boxes: | |
| if int(cls) == 0: # Person | |
| person_box = (int(x1), int(y1), int(x2), int(y2)) | |
| break | |
| # Các biến hiển thị UI | |
| status_html = "<div style='background:green; color:white; padding:10px; border-radius:5px'>🟢 AN TOÀN</div>" | |
| log_entry = "" | |
| # --- LOGIC XỬ LÝ (Giống app_new.py) --- | |
| if person_box is None: | |
| self.no_detect_count += 1 | |
| if self.no_detect_count >= 10: | |
| self.buffer.clear() | |
| # Nếu đang ghi video thì dừng lại | |
| self._stop_recording_if_active(save=True) | |
| cv2.putText(frame_bgr, "Khong thay nguoi", (20, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 2) | |
| else: | |
| self.no_detect_count = 0 | |
| x1, y1, x2, y2 = person_box | |
| # Thêm frame vào buffer LSTM | |
| frame_tensor = transform(image=image)['image'] # Image đã là RGB từ Gradio | |
| self.buffer.append(frame_tensor) | |
| # Vẽ box người | |
| cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2) | |
| # Chỉ predict khi đủ 16 frames | |
| if len(self.buffer) == self.num_frames: | |
| video_tensor = torch.stack(list(self.buffer)).unsqueeze(0).to(DEVICE) | |
| with torch.no_grad(): | |
| output = model(video_tensor) | |
| prob = torch.sigmoid(output).item() | |
| is_fall = prob > 0.5 | |
| if is_fall: | |
| # --- PHÁT HIỆN NGÃ --- | |
| status_html = "<div style='background:red; color:white; padding:10px; border-radius:5px'>🔴 NGUY HIỂM: TÉ NGÃ</div>" | |
| label = f"TE NGA! ({prob*100:.0f}%)" | |
| cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), (0, 0, 255), 3) | |
| cv2.putText(frame_bgr, label, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2) | |
| log_entry = f"<p style='color:#ff4444'>🔴 {current_time}: Phát hiện ngã ({prob*100:.0f}%)</p>" | |
| # BẮT ĐẦU GHI VIDEO (Nếu chưa ghi) | |
| if not self.is_recording: | |
| self._start_recording(frame_bgr) | |
| # Ghi frame hiện tại | |
| if self.video_writer: | |
| self.video_writer.write(frame_bgr) | |
| self.fall_frame_count += 1 | |
| else: | |
| # --- BÌNH THƯỜNG --- | |
| label = f"An toan ({prob*100:.0f}%)" | |
| cv2.putText(frame_bgr, label, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2) | |
| # log_entry = f"<p style='color:#44ff44'>🟢 {current_time}: Bình thường</p>" # Uncomment nếu muốn spam log | |
| # DỪNG GHI VIDEO (Nếu đang ghi) | |
| self._stop_recording_if_active(save=True) | |
| else: | |
| cv2.putText(frame_bgr, f"Buffering: {len(self.buffer)}/32", (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 0), 2) | |
| # Cập nhật Log | |
| if log_entry: | |
| self.log_history.insert(0, log_entry) | |
| if len(self.log_history) > 50: self.log_history.pop() | |
| log_html_output = "".join(self.log_history) | |
| # Convert back to RGB for Gradio display | |
| frame_rgb_out = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB) | |
| return frame_rgb_out, status_html, log_html_output, self.saved_videos | |
| # --- HELPER METHODS FOR RECORDING --- | |
| def _start_recording(self, frame_sample): | |
| timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') | |
| filename = f"fall_detect_{timestamp}.mp4" | |
| filepath = self.output_dir / filename | |
| h, w = frame_sample.shape[:2] | |
| fourcc = cv2.VideoWriter_fourcc(*'mp4v') # mp4v tương thích tốt hơn | |
| self.video_writer = cv2.VideoWriter(str(filepath), fourcc, 20.0, (w, h)) | |
| self.is_recording = True | |
| self.current_video_path = str(filepath) | |
| self.fall_frame_count = 0 | |
| # Ghi lại các frame quá khứ (30 frame trước khi ngã) | |
| for past_frame in self.pre_buffer: | |
| self.video_writer.write(past_frame) | |
| def _stop_recording_if_active(self, save=True): | |
| if self.is_recording and self.video_writer: | |
| self.video_writer.release() | |
| self.video_writer = None | |
| self.is_recording = False | |
| # Logic lưu video | |
| if save and self.fall_frame_count > 10: # Chỉ lưu nếu video đủ dài | |
| self.saved_videos.insert(0, self.current_video_path) | |
| else: | |
| # Xóa file rác nếu video quá ngắn | |
| try: | |
| os.remove(self.current_video_path) | |
| except: pass | |
| # Khởi tạo hệ thống | |
| system = FallDetectionSystem() | |
| # ============================================================ | |
| # 3. GRADIO UI | |
| # ============================================================ | |
| # Custom CSS | |
| css = """ | |
| .status-box { text-align: center; font-size: 1.2em; font-weight: bold; margin-bottom: 10px; } | |
| .log-container { height: 300px; overflow-y: auto; background: #222; padding: 10px; border-radius: 8px; border: 1px solid #444; } | |
| """ | |
| with gr.Blocks(title="Hệ thống Dự đoán Fall", css=css, theme=gr.themes.Soft()) as demo: | |
| gr.Markdown("# 🎈 Hệ thống Phát hiện Té ngã (AI Powered)") | |
| with gr.Tab("📹 Dự đoán Realtime"): | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| # Input Webcam | |
| input_cam = gr.Image(sources=["webcam"], type="numpy", label="Camera Input") | |
| # Output đã vẽ box | |
| output_cam = gr.Image(label="Kết quả Xử lý") | |
| with gr.Column(scale=1): | |
| # Trạng thái An toàn/Nguy hiểm | |
| status_html = gr.HTML(value="<div style='background:gray; color:white; padding:10px; border-radius:5px'>⚪ CHỜ CAMERA</div>", elem_classes="status-box") | |
| # Nhật ký Log | |
| gr.Markdown("### 📝 Nhật ký phát hiện") | |
| log_display = gr.HTML(elem_classes="log-container") | |
| # Section Video đã lưu | |
| gr.Markdown("---") | |
| gr.Markdown("### 📂 Video té ngã đã ghi lại tự động") | |
| # Gallery hiển thị video đã lưu | |
| gallery = gr.Gallery(label="Video Té Ngã", columns=3, height="auto", object_fit="contain") | |
| # Sự kiện Stream Realtime | |
| input_cam.stream( | |
| fn=system.process_frame, | |
| inputs=[input_cam], | |
| outputs=[output_cam, status_html, log_display, gallery], | |
| show_progress=False | |
| ) | |
| # Sự kiện xóa buffer khi tắt/bật camera (clear log) | |
| input_cam.clear(fn=system.reset_realtime_state, inputs=None, outputs=None) | |
| with gr.Tab("📹 Dự đoán qua Video"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| video_input = gr.Video(label="Tải video lên") | |
| analyze_btn = gr.Button("🔵 Bắt đầu phân tích", variant="primary") | |
| result_text = gr.Label(label="Kết quả phân tích") | |
| with gr.Column(): | |
| gr.Markdown("### 📊 Lịch sử phân tích") | |
| history_table = gr.Dataframe( | |
| headers=["Thời gian", "Video", "Kết quả", "Độ tin cậy"], | |
| datatype=["str", "str", "str", "str"], | |
| value=pd.DataFrame(columns=["Thời gian", "Video", "Kết quả", "Độ tin cậy"]), | |
| interactive=False | |
| ) | |
| clear_hist_btn = gr.Button("🗑️ Xóa lịch sử") | |
| # Sự kiện nút bấm | |
| analyze_btn.click( | |
| fn=system.analyze_video, | |
| inputs=video_input, | |
| outputs=[result_text, history_table] | |
| ) | |
| def clear_history(): | |
| system.analysis_history = pd.DataFrame(columns=["Thời gian", "Video", "Kết quả", "Độ tin cậy"]) | |
| return system.analysis_history | |
| clear_hist_btn.click(fn=clear_history, outputs=history_table) | |
| if __name__ == "__main__": | |
| demo.launch() |