import os
import sys
import asyncio
import tempfile
import traceback
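
# Redirect home/cache directories to /tmp so they are writable on read-only
# hosts (e.g. Hugging Face Spaces), silence TensorFlow logging, and let the
# GPU allocator grow memory on demand.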
os.environ["HOME"] = "/tmp" |
|
os.environ["STREAMLIT_HOME"] = "/tmp" |
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" |
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" |
|
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" |
|
|
|
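
# streamlit_webrtc needs an asyncio event loop; on Linux, create one if the
# current thread does not have one yet.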
if sys.platform.startswith('linux'):
    try:
        asyncio.get_event_loop()
    except RuntimeError:
        asyncio.set_event_loop(asyncio.new_event_loop())

import av
import cv2
import numpy as np
from PIL import Image

import streamlit as st
from streamlit_webrtc import VideoProcessorBase, webrtc_streamer, RTCConfiguration
from huggingface_hub import hf_hub_download
from twilio.rest import Client
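
# Prefer Twilio's TURN/STUN servers when credentials are available; otherwise
# fall back to Google's public STUN server (STUN alone can fail behind
# restrictive NATs).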
account_sid = os.environ.get("ACCOUNT_SID")
auth_token = os.environ.get("AUTH_TOKEN")
ICE_SERVERS = [{"urls": ["stun:stun.l.google.com:19302"]}]
if account_sid and auth_token:
    try:
        twilio_client = Client(account_sid, auth_token)
        token = twilio_client.tokens.create()
        try:
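            # "urls" may be a single string or a list of strings; keep only
            # servers that offer a UDP transport.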
            ICE_SERVERS = [
                server for server in token.ice_servers
                if any("udp" in url for url in ([server["urls"]] if isinstance(server["urls"], str) else server["urls"]))
            ]
            st.success("✅ Using Twilio TURN/STUN servers with UDP")
        except Exception:
            ICE_SERVERS = token.ice_servers
            st.success("✅ Using Twilio TURN/STUN servers")
    except Exception as e:
        st.error(f"❌ Failed to get ICE servers from Twilio: {e}")
        st.text(traceback.format_exc())
else:
    st.warning("⚠️ Twilio credentials not set. Falling back to STUN-only.")

import tensorflow as tf
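
# Enable memory growth so TensorFlow does not grab all GPU memory up front;
# this must run before any GPU has been initialized.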
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)

from nets import get_model_from_name
from utils.utils import (cvtColor, get_classes, letterbox_image, preprocess_input)

cache_dir = os.path.join(tempfile.gettempdir(), "hf_cache")
os.makedirs(cache_dir, exist_ok=True)
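

# Weights are fetched from the Hugging Face Hub once at import time and cached
# under /tmp, so later runs reuse the download.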
class Classification(object):
    _defaults = {
        "model_path": hf_hub_download(
            repo_id="sudo-paras-shah/micro-expression-casme2",
            filename="ep089.weights.h5",
            cache_dir=cache_dir
        ),
        "classes_path": 'src/model_data/cls_classes.txt',
        "input_shape": [224, 224],
        "backbone": 'vgg16',
        "alpha": 0.25
    }

    @classmethod
    def get_defaults(cls, n):
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
        self.class_names, self.num_classes = get_classes(self.classes_path)
        self.generate()

    def generate(self):
        model_path = os.path.expanduser(self.model_path)
        assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'
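        # Only the MobileNet backbone takes a width-multiplier (alpha) argument.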
if self.backbone == "mobilenet": |
|
self.model = get_model_from_name[self.backbone]( |
|
input_shape=[self.input_shape[0], self.input_shape[1], 3], |
|
classes=self.num_classes, |
|
alpha=self.alpha |
|
) |
|
else: |
|
self.model = get_model_from_name[self.backbone]( |
|
input_shape=[self.input_shape[0], self.input_shape[1], 3], |
|
classes=self.num_classes |
|
) |
|
self.model.load_weights(self.model_path) |
|
print('{} model, and classes {} loaded.'.format(model_path, self.class_names)) |
|
|
|

    def detect_image(self, image):
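        # Convert to RGB, letterbox-resize to the network input size, normalize,
        # and add a batch dimension before predicting.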
        image = cvtColor(image)
        image_data = letterbox_image(image, [self.input_shape[1], self.input_shape[0]])
        image_data = np.expand_dims(preprocess_input(np.array(image_data, np.float32)), 0)
        preds = self.model.predict(image_data)[0]
        class_name = self.class_names[np.argmax(preds)]
        probability = np.max(preds)
        return class_name, probability


if __name__ == '__main__':
    @st.cache_resource
    def get_model():
        return Classification()

    classificator = get_model()
    face_cascade = cv2.CascadeClassifier(
        cv2.data.haarcascades + 'haarcascade_frontalface_alt.xml'
    )

    if face_cascade.empty():
        st.error("Failed to load the Haar cascade XML. Check the path.")
st.title("Real-Time Micro-Emotion Recognition") |
|
st.write("Turn on your camera and detect emotions in real-time.") |
|
|
|
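
    # face_detect() is called from the WebRTC worker thread, where Streamlit
    # APIs are unavailable, so errors are reported with print() instead of
    # st.error().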
    def face_detect(img):
        try:
            img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
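            # minNeighbors=1 is deliberately permissive: more detections, but
            # also more false positives. Raise it for stricter filtering.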
            faces = face_cascade.detectMultiScale(
                img_gray,
                scaleFactor=1.1,
                minNeighbors=1,
                minSize=(30, 30)
            )
            return img, img_gray, faces
        except Exception as e:
            print(f"OpenCV face detection error: {e}")
            return img, np.zeros_like(img), []

    def map_emotion_to_class(emotion):
        positive = ['happiness', 'happy']
        negative = ['disgust', 'sadness', 'fear', 'sad', 'angry', 'disgusted']
        surprise = ['surprise']
        # Everything else ('repression', 'tense', 'neutral', 'others', ...)
        # falls through to 'Others'.
        e = emotion.lower()
        if any(p in e for p in positive):
            return 'Positive'
        elif any(n in e for n in negative):
            return 'Negative'
        elif any(s in e for s in surprise):
            return 'Surprise'
        else:
            return 'Others'
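
    # recv() runs in a WebRTC worker thread where Streamlit APIs (including
    # st.session_state) are unavailable, so all mutable state lives on the
    # processor instance and is read from the main script via
    # ctx.video_processor.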
    class EmotionRecognitionProcessor(VideoProcessorBase):
        def __init__(self):
            self.last_class = None
            self.rapid_change_count = 0
            self.frame_count = 0
            self.last_faces = []
            self.last_img_gray = None
            self.last_results = []
            self.emotion_history = []

        def recv(self, frame):
            border_color = (255, 0, 0)
            font_color = (0, 0, 255)
            try:
                img = frame.to_ndarray(format="bgr24")
                self.frame_count += 1
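
                # Run detection on every other frame and cache the results;
                # the frames in between redraw the cached annotations below.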
                if self.frame_count % 2 == 0:
                    img_disp, img_gray, faces = face_detect(img)
                    self.last_faces = faces
                    self.last_img_gray = img_gray
                    self.last_results = []
                    current_class = None

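                    # Classify each detected face; boxes are padded by 10 px
                    # and clamped to the frame bounds.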
                    for (x, y, w, h) in faces:
                        x1, y1 = max(x - 10, 0), max(y - 10, 0)
                        x2 = min(x + w + 10, img_disp.shape[1])
                        y2 = min(y + h + 10, img_disp.shape[0])

                        face_img_gray = img_gray[y1:y2, x1:x2]
                        if face_img_gray.size == 0:
                            continue
                        face_img_pil = Image.fromarray(face_img_gray)
                        emotion, probability = classificator.detect_image(face_img_pil)
                        emotion_class = map_emotion_to_class(emotion)

                        self.last_results.append((x1, y1, x2, y2, emotion, probability, emotion_class))
                        current_class = emotion_class
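
                    # Keep a sliding window of the last 10 classifications and
                    # count how long the detected class has been fluctuating.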
                    if current_class:
                        history = self.emotion_history
                        history.append(current_class)
                        if len(history) > 10:
                            history.pop(0)
                        if len(history) >= 3 and len(set(history[-3:])) > 1:
                            self.rapid_change_count += 1
                        else:
                            self.rapid_change_count = 0

                else:
                    img_disp = img.copy()
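
                # Draw the cached annotations on every frame, whether or not
                # detection ran on this one.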
                if len(self.last_faces) == 0:
                    cv2.putText(
                        img_disp, 'No Face Detected', (2, 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 1
                    )

                for (x1, y1, x2, y2, emotion, probability, emotion_class) in self.last_results:
                    cv2.rectangle(
                        img_disp,
                        (x1, y1),
                        (x2, y2),
                        border_color,
                        thickness=2
                    )
                    cv2.putText(
                        img_disp, emotion, (x1 + 30, y1 - 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, font_color, 1
                    )
                    cv2.putText(
                        img_disp, str(round(probability, 3)), (x1 + 30, y1 - 50),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.3, font_color, 1
                    )

                return av.VideoFrame.from_ndarray(img_disp, format="bgr24")
            except Exception as e:
                print(f"Error in video processing: {e}")
                traceback.print_exc()
                return frame

    RTC_CONFIGURATION = RTCConfiguration({"iceServers": ICE_SERVERS})
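
    # Keep the returned context so the main script can read state off the
    # running video processor.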
    ctx = webrtc_streamer(
        key="emotion-detection",
        video_processor_factory=EmotionRecognitionProcessor,
        rtc_configuration=RTC_CONFIGURATION,
        media_stream_constraints={"video": True, "audio": False},
    )
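
    # Read the history back on the script side; this runs on each Streamlit
    # rerun rather than per frame.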
    if ctx.video_processor:
        history = ctx.video_processor.emotion_history
        if len(history) >= 3 and len(set(history[-3:])) > 1:
            st.warning(
                "⚠️ Rapid changes in your detected emotional state were observed. "
                "Micro-expressions may not always reflect your true feelings. "
                "If you feel emotionally unstable or distressed, consider "
                "reaching out to a mental health professional, talking it over "
                "with someone close to you, or taking a break."
            )