# VioMobileNet / inference.py
# quochuy
# update tcp
# b01f428
# inference.py
import cv2
import tensorflow as tf
import time
import base64
import datetime
import os
# Load the model exactly once, at import time; every caller in this module
# shares the resulting `model` / `infer` objects.
MODEL_PATH = os.getenv('MODEL_PATH', 'model')
RESOLUTION = int(os.getenv('RESOLUTION', 172))
# Minimum probability required before a frame is flagged as violent.
# Overridable through the environment, like MODEL_PATH and RESOLUTION.
CONFIDENCE_THRESHOLD = float(os.getenv('CONFIDENCE_THRESHOLD', 0.65))
print(f"Loading MoViNet from {MODEL_PATH}...")
model = tf.saved_model.load(MODEL_PATH)
infer = model.signatures['serving_default']
print("Model loaded!")
def get_init_states():
    """Return the model's initial streaming states for single-frame input.

    The states are requested for the shape of a float32 clip of
    ``(1, 1, RESOLUTION, RESOLUTION, 3)``, which is the layout of each
    per-frame tensor built for inference in this module.
    """
    clip_shape = [1, 1, RESOLUTION, RESOLUTION, 3]
    shape_tensor = tf.shape(tf.zeros(clip_shape, dtype=tf.float32))
    return model.init_states(shape_tensor)
# OpenCV FFmpeg-backend settings. They are applied at import time, so they
# are already in the environment before any VideoCapture in this module opens.
# "rtsp_transport;tcp" asks FFmpeg to carry the RTSP stream over TCP.
# NOTE(review): confirm OPENCV_FFMPEG_INTERRUPT_TIMEOUT is recognised by the
# installed OpenCV build; the value is presumably milliseconds (60 s).
os.environ.update({
    "OPENCV_FFMPEG_INTERRUPT_TIMEOUT": "60000",
    "OPENCV_FFMPEG_CAPTURE_OPTIONS": "rtsp_transport;tcp",
})
class _ViolenceEventTracker:
    """Turn per-frame violence decisions into START/END event messages.

    An event opens on the first violent frame. It closes once
    ``cooldown_limit`` consecutive non-violent frames have been seen; any
    violent frame during the cooldown resets the counter.
    """

    def __init__(self, cooldown_limit):
        self.cooldown_limit = cooldown_limit
        self.in_event = False
        self.event_start_time = None
        self.cooldown_counter = 0

    def update(self, is_violence, current_time, score):
        """Advance the tracker by one frame.

        Returns a START message dict (without the evidence image) when an
        event opens on this frame. Returns an END message dict when the
        cooldown closes the event, and None otherwise.
        """
        if is_violence:
            self.cooldown_counter = 0
            if self.in_event:
                return None
            self.in_event = True
            self.event_start_time = current_time
            return {
                "type": "START",
                "timestamp": current_time.isoformat(),
                "score": score,
            }
        if not self.in_event:
            return None
        self.cooldown_counter += 1
        if self.cooldown_counter < self.cooldown_limit:
            return None
        return self.close(current_time)

    def close(self, current_time):
        """Close the open event and return its END message; None if none is open.

        The duration runs from the event's first violent frame to
        ``current_time``. When the close is triggered by the cooldown, the
        duration therefore includes the cooldown window itself.
        """
        if not self.in_event:
            return None
        duration = (current_time - self.event_start_time).total_seconds()
        self.in_event = False
        self.cooldown_counter = 0
        return {
            "type": "END",
            "timestamp": current_time.isoformat(),
            "duration": duration,
        }


class VideoProcessor:
    """Run frame-by-frame streaming inference on an RTSP feed.

    Results are pushed onto ``result_queue`` as dicts:

    * ``{"error": ...}`` if the stream cannot be opened;
    * ``START``/``END`` event messages while running;
    * a final ``{"status": "Stream stopped"}``.
    """

    # Consecutive non-violent frames needed to close an event. The original
    # author estimated this as ~2-3 s; the real time depends on the stream FPS.
    COOLDOWN_LIMIT = 30

    def __init__(self):
        # Polled once per frame by start_processing. Clear it (see stop()) to
        # end the loop; presumably done from another thread, since
        # start_processing blocks.
        self.running = False

    def stop(self):
        """Ask start_processing to exit after it finishes the current frame."""
        self.running = False

    def start_processing(self, rtsp_url, result_queue):
        """Open ``rtsp_url`` and stream inference results into ``result_queue``.

        Blocks until a frame read fails (e.g. the stream ends) or ``running``
        is cleared. ``result_queue`` only needs a ``put`` method (presumably a
        queue.Queue or multiprocessing queue; confirm against the caller).

        On exit, whether normal or via an exception (which is re-raised):

        * the capture is released and ``running`` is reset;
        * an event still open is closed with an END message;
        * the "Stream stopped" status is sent.
        """
        self.running = True
        print(f"Trying to open: {rtsp_url}")
        cap = cv2.VideoCapture(rtsp_url, cv2.CAP_FFMPEG)
        if not cap.isOpened():
            cap.release()
            self.running = False
            result_queue.put({"error": "Cannot open RTSP URL"})
            return
        print("RTSP Stream Opened successfully with TCP!")

        tracker = _ViolenceEventTracker(self.COOLDOWN_LIMIT)
        try:
            states = get_init_states()
            while self.running:
                ret, frame = cap.read()
                if not ret:
                    break
                # 1. Inference on this frame, carrying the streaming states.
                fight_conf, norm_conf, states = self._classify_frame(frame, states)
                is_violence = (
                    fight_conf > norm_conf and fight_conf > CONFIDENCE_THRESHOLD
                )
                # 2. Event bookkeeping; only START/END messages are sent back.
                msg = tracker.update(is_violence, datetime.datetime.now(), fight_conf)
                if msg is None:
                    continue
                if msg["type"] == "START":
                    # START carries a snapshot of the triggering frame as evidence.
                    msg["image"] = self._encode_evidence(frame)
                result_queue.put(msg)
                # TODO: optionally send a heartbeat every ~5 s so the receiver
                # knows the model is still alive.
        finally:
            cap.release()
            self.running = False
            # Do not leave the receiver with a START that never gets an END.
            end_msg = tracker.close(datetime.datetime.now())
            if end_msg is not None:
                result_queue.put(end_msg)
            result_queue.put({"status": "Stream stopped"})

    @staticmethod
    def _classify_frame(frame, states):
        """Run one BGR frame through the streaming model.

        Returns ``(fight_conf, norm_conf, new_states)``:

        * ``fight_conf`` / ``norm_conf`` are the softmax probabilities of
          class indices 0 and 1;
        * ``new_states`` are the model outputs other than ``logits``, which
          are fed into the next call.
        """
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        resized = tf.image.resize(rgb, [RESOLUTION, RESOLUTION])
        # Divide by 255 and add batch and time axes: one frame per call.
        clip = tf.cast(resized, tf.float32) / 255.0
        clip = clip[tf.newaxis, tf.newaxis, ...]
        outputs = infer(image=clip, **states)
        new_states = {k: v for k, v in outputs.items() if k != 'logits'}
        probs = tf.nn.softmax(outputs['logits'], axis=-1)[0]
        return float(probs[0]), float(probs[1]), new_states

    @staticmethod
    def _encode_evidence(frame):
        """Return ``frame`` resized to 640x360, JPEG-encoded, as a base64 str."""
        small_frame = cv2.resize(frame, (640, 360))
        _, buffer = cv2.imencode('.jpg', small_frame)
        return base64.b64encode(buffer).decode('utf-8')