fall_detection / app.py
ltlonggg's picture
changes num_frames
d676491
import os
# Cấu hình thư mục tạm cho YOLO (Bắt buộc cho HuggingFace)
os.environ["YOLO_CONFIG_DIR"] = "/tmp"
import gradio as gr
import cv2
import numpy as np
import torch
import torch.nn as nn
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
import albumentations as A
from ultralytics import YOLO
from datetime import datetime
import pandas as pd
from collections import deque
from pathlib import Path
# ============================================================
# 1. MODEL CONFIGURATION (Giữ nguyên logic của bạn)
# ============================================================
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_PATH = "best_model_efficientnet_lstm_v2.pth"
class EfficientNetLSTM(nn.Module):
def __init__(self, hidden_size=256, num_layers=2, dropout=0.5):
super(EfficientNetLSTM, self).__init__()
weights = EfficientNet_V2_S_Weights.IMAGENET1K_V1
self.efficientnet = efficientnet_v2_s(weights=weights)
num_features = self.efficientnet.classifier[1].in_features
self.efficientnet.classifier = nn.Identity()
self.lstm = nn.LSTM(input_size=num_features, hidden_size=hidden_size, num_layers=num_layers,
batch_first=True, dropout=dropout, bidirectional=True)
self.fc = nn.Sequential(
nn.Linear(256*2, 256), nn.ReLU(), nn.Dropout(dropout),
nn.Linear(256, 128), nn.ReLU(), nn.Dropout(dropout),
nn.Linear(128, 1)
)
def forward(self, x):
batch_size, num_frames, c, h, w = x.shape
x = x.view(batch_size * num_frames, c, h, w)
features = self.efficientnet(x)
features = features.view(batch_size, num_frames, -1)
lstm_out, _ = self.lstm(features)
output = self.fc(lstm_out[:, -1, :])
return output.squeeze()
# Load Models Global
print("⏳ Đang tải models...")
try:
model = EfficientNetLSTM().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()
yolo_model = YOLO("yolov8n.pt")
print("✅ Đã tải xong models!")
except Exception as e:
print(f"❌ Lỗi: {e}")
model = None
yolo_model = None
# Transform
transform = A.Compose([
A.Resize(height=224, width=224),
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
A.ToTensorV2(),
])
# ============================================================
# 2. SYSTEM CLASS (QUẢN LÝ TRẠNG THÁI)
# ============================================================
class FallDetectionSystem:
def __init__(self):
# Config
self.num_frames = 32
self.conf_thres = 0.5
self.output_dir = Path("fall_videos")
self.output_dir.mkdir(exist_ok=True)
# Realtime Buffers
self.buffer = deque(maxlen=self.num_frames) # Buffer cho model
self.pre_buffer = deque(maxlen=30) # Buffer lưu 30 frame trước khi ngã
self.no_detect_count = 0
# Recording State
self.is_recording = False
self.video_writer = None
self.current_video_path = None
self.fall_start_time = None
self.fall_frame_count = 0
# Logging & History
self.log_history = [] # Cho realtime text log
self.saved_videos = [] # List đường dẫn video đã lưu
self.analysis_history = pd.DataFrame(columns=["Thời gian", "Video", "Kết quả", "Độ tin cậy"])
def reset_realtime_state(self):
"""Reset trạng thái khi bật lại camera"""
self.buffer.clear()
self.pre_buffer.clear()
self.is_recording = False
if self.video_writer:
self.video_writer.release()
self.video_writer = None
# --- LOGIC TAB 1: VIDEO FILE ANALYSIS ---
def analyze_video(self, video_path):
if model is None: return "Error loading model", self.analysis_history
cap = cv2.VideoCapture(video_path)
frames = []
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
# Logic lấy 16 frames (như code cũ)
if total_frames >= 32:
indices = np.linspace(0, total_frames - 1, 32, dtype=int)
else:
indices = np.arange(total_frames)
for i in range(total_frames):
ret, frame = cap.read()
if not ret: break
if i in indices:
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frames.append(transform(image=frame_rgb)['image'])
cap.release()
# Pad frame nếu thiếu
while len(frames) < 32: frames.append(frames[-1])
# Predict
video_tensor = torch.stack(frames).unsqueeze(0).to(DEVICE)
with torch.no_grad():
prob = torch.sigmoid(model(video_tensor)).item()
is_fall = prob > 0.5
result_text = "⚠️ PHÁT HIỆN NGÃ" if is_fall else "✅ AN TOÀN"
timestamp = datetime.now().strftime("%d/%m/%Y %H:%M")
filename = os.path.basename(video_path)
# Cập nhật DataFrame
new_row = pd.DataFrame({
"Thời gian": [timestamp],
"Video": [filename],
"Kết quả": [result_text],
"Độ tin cậy": [f"{prob*100:.2f}%"]
})
self.analysis_history = pd.concat([new_row, self.analysis_history], ignore_index=True)
return f"{result_text} ({prob*100:.2f}%)", self.analysis_history
# --- LOGIC TAB 2: REALTIME PROCESSING ---
def process_frame(self, image):
"""Hàm xử lý chính cho mỗi frame từ webcam"""
if image is None: return image, "", "", []
# 1. Chuẩn bị dữ liệu
frame_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # OpenCV dùng BGR
current_time = datetime.now().strftime('%H:%M:%S')
# Thêm vào pre-buffer (để ghi video lùi lại quá khứ)
self.pre_buffer.append(frame_bgr)
# 2. Detect Người (YOLO)
results = yolo_model(frame_bgr, verbose=False, conf=self.conf_thres)
boxes = results[0].boxes.data.cpu().numpy()
person_box = None
for x1, y1, x2, y2, conf, cls in boxes:
if int(cls) == 0: # Person
person_box = (int(x1), int(y1), int(x2), int(y2))
break
# Các biến hiển thị UI
status_html = "<div style='background:green; color:white; padding:10px; border-radius:5px'>🟢 AN TOÀN</div>"
log_entry = ""
# --- LOGIC XỬ LÝ (Giống app_new.py) ---
if person_box is None:
self.no_detect_count += 1
if self.no_detect_count >= 10:
self.buffer.clear()
# Nếu đang ghi video thì dừng lại
self._stop_recording_if_active(save=True)
cv2.putText(frame_bgr, "Khong thay nguoi", (20, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 2)
else:
self.no_detect_count = 0
x1, y1, x2, y2 = person_box
# Thêm frame vào buffer LSTM
frame_tensor = transform(image=image)['image'] # Image đã là RGB từ Gradio
self.buffer.append(frame_tensor)
# Vẽ box người
cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2)
# Chỉ predict khi đủ 16 frames
if len(self.buffer) == self.num_frames:
video_tensor = torch.stack(list(self.buffer)).unsqueeze(0).to(DEVICE)
with torch.no_grad():
output = model(video_tensor)
prob = torch.sigmoid(output).item()
is_fall = prob > 0.5
if is_fall:
# --- PHÁT HIỆN NGÃ ---
status_html = "<div style='background:red; color:white; padding:10px; border-radius:5px'>🔴 NGUY HIỂM: TÉ NGÃ</div>"
label = f"TE NGA! ({prob*100:.0f}%)"
cv2.rectangle(frame_bgr, (x1, y1), (x2, y2), (0, 0, 255), 3)
cv2.putText(frame_bgr, label, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
log_entry = f"<p style='color:#ff4444'>🔴 {current_time}: Phát hiện ngã ({prob*100:.0f}%)</p>"
# BẮT ĐẦU GHI VIDEO (Nếu chưa ghi)
if not self.is_recording:
self._start_recording(frame_bgr)
# Ghi frame hiện tại
if self.video_writer:
self.video_writer.write(frame_bgr)
self.fall_frame_count += 1
else:
# --- BÌNH THƯỜNG ---
label = f"An toan ({prob*100:.0f}%)"
cv2.putText(frame_bgr, label, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
# log_entry = f"<p style='color:#44ff44'>🟢 {current_time}: Bình thường</p>" # Uncomment nếu muốn spam log
# DỪNG GHI VIDEO (Nếu đang ghi)
self._stop_recording_if_active(save=True)
else:
cv2.putText(frame_bgr, f"Buffering: {len(self.buffer)}/32", (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 0), 2)
# Cập nhật Log
if log_entry:
self.log_history.insert(0, log_entry)
if len(self.log_history) > 50: self.log_history.pop()
log_html_output = "".join(self.log_history)
# Convert back to RGB for Gradio display
frame_rgb_out = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
return frame_rgb_out, status_html, log_html_output, self.saved_videos
# --- HELPER METHODS FOR RECORDING ---
def _start_recording(self, frame_sample):
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
filename = f"fall_detect_{timestamp}.mp4"
filepath = self.output_dir / filename
h, w = frame_sample.shape[:2]
fourcc = cv2.VideoWriter_fourcc(*'mp4v') # mp4v tương thích tốt hơn
self.video_writer = cv2.VideoWriter(str(filepath), fourcc, 20.0, (w, h))
self.is_recording = True
self.current_video_path = str(filepath)
self.fall_frame_count = 0
# Ghi lại các frame quá khứ (30 frame trước khi ngã)
for past_frame in self.pre_buffer:
self.video_writer.write(past_frame)
def _stop_recording_if_active(self, save=True):
if self.is_recording and self.video_writer:
self.video_writer.release()
self.video_writer = None
self.is_recording = False
# Logic lưu video
if save and self.fall_frame_count > 10: # Chỉ lưu nếu video đủ dài
self.saved_videos.insert(0, self.current_video_path)
else:
# Xóa file rác nếu video quá ngắn
try:
os.remove(self.current_video_path)
except: pass
# Khởi tạo hệ thống
system = FallDetectionSystem()
# ============================================================
# 3. GRADIO UI
# ============================================================
# Custom CSS
css = """
.status-box { text-align: center; font-size: 1.2em; font-weight: bold; margin-bottom: 10px; }
.log-container { height: 300px; overflow-y: auto; background: #222; padding: 10px; border-radius: 8px; border: 1px solid #444; }
"""
with gr.Blocks(title="Hệ thống Dự đoán Fall", css=css, theme=gr.themes.Soft()) as demo:
gr.Markdown("# 🎈 Hệ thống Phát hiện Té ngã (AI Powered)")
with gr.Tab("📹 Dự đoán Realtime"):
with gr.Row():
with gr.Column(scale=2):
# Input Webcam
input_cam = gr.Image(sources=["webcam"], type="numpy", label="Camera Input")
# Output đã vẽ box
output_cam = gr.Image(label="Kết quả Xử lý")
with gr.Column(scale=1):
# Trạng thái An toàn/Nguy hiểm
status_html = gr.HTML(value="<div style='background:gray; color:white; padding:10px; border-radius:5px'>⚪ CHỜ CAMERA</div>", elem_classes="status-box")
# Nhật ký Log
gr.Markdown("### 📝 Nhật ký phát hiện")
log_display = gr.HTML(elem_classes="log-container")
# Section Video đã lưu
gr.Markdown("---")
gr.Markdown("### 📂 Video té ngã đã ghi lại tự động")
# Gallery hiển thị video đã lưu
gallery = gr.Gallery(label="Video Té Ngã", columns=3, height="auto", object_fit="contain")
# Sự kiện Stream Realtime
input_cam.stream(
fn=system.process_frame,
inputs=[input_cam],
outputs=[output_cam, status_html, log_display, gallery],
show_progress=False
)
# Sự kiện xóa buffer khi tắt/bật camera (clear log)
input_cam.clear(fn=system.reset_realtime_state, inputs=None, outputs=None)
with gr.Tab("📹 Dự đoán qua Video"):
with gr.Row():
with gr.Column():
video_input = gr.Video(label="Tải video lên")
analyze_btn = gr.Button("🔵 Bắt đầu phân tích", variant="primary")
result_text = gr.Label(label="Kết quả phân tích")
with gr.Column():
gr.Markdown("### 📊 Lịch sử phân tích")
history_table = gr.Dataframe(
headers=["Thời gian", "Video", "Kết quả", "Độ tin cậy"],
datatype=["str", "str", "str", "str"],
value=pd.DataFrame(columns=["Thời gian", "Video", "Kết quả", "Độ tin cậy"]),
interactive=False
)
clear_hist_btn = gr.Button("🗑️ Xóa lịch sử")
# Sự kiện nút bấm
analyze_btn.click(
fn=system.analyze_video,
inputs=video_input,
outputs=[result_text, history_table]
)
def clear_history():
system.analysis_history = pd.DataFrame(columns=["Thời gian", "Video", "Kết quả", "Độ tin cậy"])
return system.analysis_history
clear_hist_btn.click(fn=clear_history, outputs=history_table)
if __name__ == "__main__":
demo.launch()