Spaces:
Sleeping
Sleeping
import logging
import os
import tempfile
import time
from io import BytesIO  # Import for handling byte streams

import av
import cv2
import numpy as np
import requests
import streamlit as st
import yt_dlp
from mtcnn import MTCNN  # Import MTCNN for face detection
from PIL import Image, ImageDraw  # Import PIL for image processing
from streamlit_webrtc import WebRtcMode, webrtc_streamer
from transformers import pipeline  # Import Hugging Face transformers pipeline

from utils.download import download_file
from utils.turn import get_ice_servers
# CHANGE CODE BELOW HERE, USE TO REPLACE WITH YOUR WANTED ANALYSIS.
#
# Models used by the analysis. Streamlit re-executes this whole script on every
# user interaction, so creating the models at module level would reload them on
# every rerun. st.cache_resource builds each one once and hands the same object
# back on later runs.
# NOTE(review): cached resources are shared by all sessions (and by the WebRTC
# callback thread); confirm MTCNN and the HF pipeline are safe to call
# concurrently in your deployment.
# show_spinner=False keeps the loaders from rendering anything, because they run
# before st.set_page_config(), which (depending on the Streamlit version) must be
# the first Streamlit command on the page.


@st.cache_resource(show_spinner=False)
def _load_face_detector():
    """Return the MTCNN face detector shared across script runs."""
    return MTCNN()


@st.cache_resource(show_spinner=False)
def _load_emotion_pipeline():
    """Return the facial-expression image-classification pipeline."""
    return pipeline("image-classification",
                    model="trpakov/vit-face-expression")


# Initialize MTCNN for face detection
mtcnn = _load_face_detector()
# Hugging Face pipeline for facial emotion detection
emotion_pipeline = _load_emotion_pipeline()

# Display title of the analysis; update this string to rename it.
# Default title - "Facial Sentiment Analysis"
ANALYSIS_TITLE = "Facial Sentiment Analysis"
# CHANGE THE CONTENTS OF THIS FUNCTION, USE TO REPLACE WITH YOUR WANTED ANALYSIS.
#
# Constants for text and line size in the output image
TEXT_SIZE = 1  # font scale of the emotion label (cv2.FONT_HERSHEY_SIMPLEX)
LINE_SIZE = 2  # line thickness of the face rectangle


def analyze_frame(frame: np.ndarray):
    """Detect faces in ``frame``, classify each face's emotion and annotate it.

    Faces are located with MTCNN. Each face crop is classified with
    ``analyze_sentiment``, a rectangle is drawn around the face, and the
    predicted emotion is written on a black background above the face
    (shifted down when needed so the label stays inside the image).

    Results are stored in the global ``img_container``:

    * ``"input"``         - the frame exactly as passed in (np.ndarray)
    * ``"analyzed"``      - an annotated copy of the frame (np.ndarray)
    * ``"analysis_time"`` - time spent in this function, in milliseconds
    * ``"detections"``    - the MTCNN result dicts, each with an added
      ``"label"`` key (the predicted emotion, or None for an empty crop)

    Args:
        frame: image array. Every caller in this app passes RGB data.
    """
    start_time = time.perf_counter()  # monotonic clock for elapsed time
    img_container["input"] = frame
    # Draw on a copy so the stored input stays clean. Face crops are taken
    # from the untouched input, so boxes/labels drawn for one face cannot leak
    # into the crop of an overlapping face.
    annotated = frame.copy()
    frame_h, frame_w = frame.shape[:2]

    results = mtcnn.detect_faces(frame)  # Detect faces in the frame
    for result in results:
        x, y, w, h = (int(v) for v in result["box"])
        # MTCNN boxes may start outside the image (negative x/y). A negative
        # slice start wraps around in NumPy and would crop the wrong region,
        # so clamp the box to the frame before cropping.
        x0, y0 = max(x, 0), max(y, 0)
        x1, y1 = min(x + w, frame_w), min(y + h, frame_h)
        face = frame[y0:y1, x0:x1]
        if face.size == 0:
            # Degenerate box: nothing to classify (cv2/PIL reject empty images).
            result["label"] = None
            continue

        # Analyze the sentiment of the face
        sentiment = analyze_sentiment(face)
        result["label"] = sentiment

        # Draw a rectangle around the face
        cv2.rectangle(annotated, (x0, y0), (x1, y1), (0, 0, 255), LINE_SIZE)

        (text_w, text_h), _ = cv2.getTextSize(
            sentiment, cv2.FONT_HERSHEY_SIMPLEX, TEXT_SIZE, 2
        )
        # Baseline 10 px above the box, but never so high that the text and
        # its background are cut off at the top edge of the image.
        text_x = x0
        text_y = max(y0 - 10, text_h)
        # Draw a black background for the text
        cv2.rectangle(
            annotated,
            (text_x, text_y - text_h),
            (text_x + text_w, text_y + 5),
            (0, 0, 0),
            cv2.FILLED,
        )
        # Put the sentiment text on the image
        cv2.putText(
            annotated,
            sentiment,
            (text_x, text_y),
            cv2.FONT_HERSHEY_SIMPLEX,
            TEXT_SIZE,
            (255, 255, 255),
            2,
        )

    # Execution time in milliseconds, rounded to 2 decimals
    img_container["analysis_time"] = round(
        (time.perf_counter() - start_time) * 1000, 2
    )
    img_container["detections"] = results
    img_container["analyzed"] = annotated
    return
# Function to analyze the sentiment (emotion) of a detected face.
def analyze_sentiment(face, input_is_bgr=False):
    """Return the most likely emotion label for a face crop.

    The crop is converted to a PIL image and classified with
    ``emotion_pipeline``; the label of the highest-scoring prediction is
    returned.

    Args:
        face: image array of the face. All callers in this app pass RGB
            crops (frames come from ``to_ndarray(format="rgb24")``,
            ``Image.convert("RGB")`` or an explicit BGR->RGB conversion), so
            by default the crop is used as-is. Converting it again from BGR
            would swap the red and blue channels before classification.
        input_is_bgr: pass True for crops taken straight from OpenCV
            (BGR channel order); they are converted to RGB first.

    Returns:
        The label string of the dominant emotion.
    """
    if input_is_bgr:
        face = cv2.cvtColor(face, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(face)  # Convert to PIL image
    predictions = emotion_pipeline(pil_image)  # Run emotion detection
    dominant = max(predictions, key=lambda p: p["score"])
    return dominant["label"]
#
#
# DO NOT TOUCH THE BELOW CODE (NOT NEEDED)
#
#
# Suppress FFmpeg logs via the FFMPEG_LOG_LEVEL environment variable, and keep
# Streamlit's own logger to errors only.
os.environ["FFMPEG_LOG_LEVEL"] = "quiet"
logging.getLogger("streamlit").setLevel(logging.ERROR)

# Shared slots written by analyze_frame() and read by publish_frame().
# Every slot starts as None until the first frame has been analyzed.
img_container = dict.fromkeys(
    ("input", "analyzed", "analysis_time", "detections"))

# Module logger for debugging and information
logger = logging.getLogger(__name__)
# WebRTC frame callback, registered with webrtc_streamer() below.
def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
    """Analyze one incoming WebRTC frame and hand it back unchanged.

    The frame is decoded to an RGB ``np.ndarray`` (``format="rgb24"``) and
    passed to ``analyze_frame``, which stores its results in
    ``img_container``. The original ``av.VideoFrame`` object is returned.
    """
    analyze_frame(frame.to_ndarray(format="rgb24"))
    return frame
# Get ICE servers for WebRTC (passed to webrtc_streamer() below as
# rtc_configuration).
ice_servers = get_ice_servers()
# Streamlit UI configuration: wide layout, so the two columns created below
# can use the full page width.
st.set_page_config(layout="wide")
# Custom CSS for the Streamlit page: padding for the main container plus the
# font family, weight and size of h1-h3 headings. Raw HTML/CSS needs
# unsafe_allow_html=True.
st.markdown(
    """
    <style>
    .main {
        padding: 2rem;
    }
    h1, h2, h3 {
        font-family: 'Arial', sans-serif;
    }
    h1 {
        font-weight: 700;
        font-size: 2.5rem;
    }
    h2 {
        font-weight: 600;
        font-size: 2rem;
    }
    h3 {
        font-weight: 500;
        font-size: 1.5rem;
    }
    </style>
    """,
    unsafe_allow_html=True,
)
# Streamlit page title
st.title("Computer Vision Playground")
# Add a link to the README file (opened in a new tab via target="_blank")
st.markdown(
    """
    <div style="text-align: left;">
        <p>See the <a href="https://huggingface.co/spaces/eusholli/sentiment-analyzer/blob/main/README.md"
        target="_blank">README</a> to learn how to use this code to help you start your computer vision exploration.</p>
    </div>
    """,
    unsafe_allow_html=True,
)
# Subtitle: the name of the analysis (ANALYSIS_TITLE)
st.subheader(ANALYSIS_TITLE)
# Columns for input and output streams. col1 is filled here with the input
# widgets; col2 is filled later by analysis_init(). Streamlit places elements
# in call order, so the statement order below is the on-page order.
col1, col2 = st.columns(2)
with col1:
    st.header("Input Stream")
    # Reserved slots: input_subheader is filled by analysis_init() and
    # input_placeholder by publish_frame().
    input_subheader = st.empty()
    input_placeholder = st.empty()  # Placeholder for input frame
    st.subheader("Input Options")
    # WebRTC streamer to get video input from the webcam. In SENDONLY mode the
    # browser only sends video; each frame goes to video_frame_callback.
    # webrtc_ctx.state.playing is checked further down to start the UI loop.
    webrtc_ctx = webrtc_streamer(
        key="input-webcam",
        mode=WebRtcMode.SENDONLY,
        rtc_configuration=ice_servers,
        video_frame_callback=video_frame_callback,
        media_stream_constraints={"video": True, "audio": False},
        async_processing=True,
    )
    # File uploader for images (None until a file is uploaded)
    st.subheader("Upload an Image")
    uploaded_file = st.file_uploader(
        "Choose an image...", type=["jpg", "jpeg", "png"])
    # Text input for image URL
    st.subheader("Or Enter Image URL")
    image_url = st.text_input("Image URL")
    # Text input for YouTube URL
    st.subheader("Enter a YouTube URL")
    youtube_url = st.text_input("YouTube URL")
    # File uploader for videos (None until a file is uploaded)
    st.subheader("Upload a Video")
    uploaded_video = st.file_uploader(
        "Choose a video...", type=["mp4", "avi", "mov", "mkv"]
    )
    # Text input for video URL
    st.subheader("Or Enter Video Download URL")
    video_url = st.text_input("Video URL")
    # Streamlit footer
    # NOTE(review): this call is still inside `with col1:`, so the footer is
    # rendered at the bottom of the input column rather than across the full
    # page; confirm that placement is intended.
    st.markdown(
        """
    <div style="text-align: center; margin-top: 2rem;">
        <p>If you want to set up your own computer vision playground see <a href="https://huggingface.co/spaces/eusholli/computer-vision-playground/blob/main/README.md" target="_blank">here</a>.</p>
    </div>
    """,
    unsafe_allow_html=True
    )
# Function to initialize the analysis UI
def analysis_init():
    """Build the "Analysis" column and publish its widget handles as globals.

    In the right-hand column (``col2``) this creates the "Analysis" header,
    an "Output Frame" subheader with its image placeholder, a placeholder for
    the analysis time, the "Show the detected labels" checkbox and a
    placeholder for the labels table. It also fills ``input_subheader`` (in
    the input column) with an "Input Frame" subheader. ``publish_frame``
    reads the resulting globals.

    Every input branch (webcam, image, YouTube, video file/URL) calls this
    function, so it can be reached several times in one script run when the
    user has filled in more than one input. Building the UI twice would create
    a second identical checkbox, which Streamlit rejects as a duplicate
    widget. Only the first call in a run does any work.
    """
    global analysis_time, show_labels, labels_placeholder, input_subheader, input_placeholder, output_placeholder
    global _analysis_ui_built
    # Read through globals(): the flag does not exist until the first call.
    # NOTE(review): relies on Streamlit executing the script in a fresh module
    # namespace on each rerun, so the flag resets every run.
    if globals().get("_analysis_ui_built", False):
        return
    with col2:
        st.header("Analysis")
        input_subheader.subheader("Input Frame")
        st.subheader("Output Frame")
        output_placeholder = st.empty()  # Placeholder for output frame
        analysis_time = st.empty()  # Placeholder for analysis time
        show_labels = st.checkbox(
            "Show the detected labels", value=True
        )  # Checkbox to show/hide labels
        labels_placeholder = st.empty()  # Placeholder for labels
    _analysis_ui_built = True
# Function to publish frames and results to the Streamlit UI
def publish_frame():
    """Push the latest results from ``img_container`` to the UI placeholders.

    Shows, in order: the input frame, the analyzed frame, the analysis time,
    and the detections table (only when the "Show the detected labels"
    checkbox is ticked). Stops at the first slot that is still ``None``, so
    partially available results are shown as far as they go.
    """
    frame_in = img_container["input"]
    if frame_in is None:
        return
    input_placeholder.image(frame_in, channels="RGB")  # Display the input frame

    annotated = img_container["analyzed"]
    if annotated is None:
        return
    output_placeholder.image(annotated, channels="RGB")  # Display the analyzed frame

    # Named elapsed_ms rather than `time` so the `time` module is not shadowed.
    elapsed_ms = img_container["analysis_time"]
    if elapsed_ms is None:
        return
    analysis_time.text(f"Analysis Time: {elapsed_ms} ms")

    detections = img_container["detections"]
    if detections is None:
        return
    if show_labels:
        labels_placeholder.table(detections)  # Display labels if checked
# If the WebRTC streamer is playing, initialize and publish frames
if webrtc_ctx.state.playing:
    analysis_init()  # Initialize the analysis UI
    # Frames are analyzed in video_frame_callback; this loop only refreshes
    # the UI with the latest results about every 0.1 s. It never exits on its
    # own (presumably it ends when Streamlit stops or reruns the script).
    while True:
        publish_frame()  # Publish the frames and results
        time.sleep(0.1)  # Delay to control refresh rate

# If an image is uploaded or a URL is provided, process the image
if uploaded_file is not None or image_url:
    analysis_init()  # Initialize the analysis UI
    image = None
    if uploaded_file is not None:
        image = Image.open(uploaded_file)  # Open the uploaded image
    else:
        # Time out instead of hanging the app on an unresponsive host, and
        # report HTTP errors directly instead of letting PIL fail on an
        # error page with a confusing "cannot identify image file".
        try:
            response = requests.get(image_url, timeout=30)
            response.raise_for_status()
        except requests.RequestException as err:
            st.error(f"Could not download image: {err}")
        else:
            image = Image.open(BytesIO(response.content))
    if image is not None:
        img = np.array(image.convert("RGB"))  # Convert the image to RGB format
        analyze_frame(img)  # Analyze the image
        publish_frame()  # Publish the results
# Function to process video files and streams
def process_video(video_path):
    """Analyze a video frame by frame, publishing each result to the UI.

    Each decoded frame is converted from OpenCV's BGR order to RGB, passed to
    ``analyze_frame`` and shown with ``publish_frame``.

    Args:
        video_path: anything ``cv2.VideoCapture`` accepts; this app passes a
            local file path or a stream URL from ``get_youtube_stream_url``.

    If the source cannot be opened, an error message is shown in the app
    instead of silently doing nothing.
    """
    cap = cv2.VideoCapture(video_path)  # Open the video file or stream
    try:
        if not cap.isOpened():
            st.error(f"Could not open video: {video_path}")
            return
        while True:
            ok, frame = cap.read()  # Read a frame from the video
            if not ok:
                break  # No more frames (or a read error)
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            analyze_frame(rgb_frame)  # Face detection + sentiment analysis
            publish_frame()  # Publish the results
    finally:
        # Release the capture even if an exception escapes the loop.
        cap.release()
# Function to get the video stream URL from YouTube using yt-dlp
def get_youtube_stream_url(youtube_url):
    """Resolve a YouTube page URL to a directly playable MP4 stream URL.

    Only metadata is extracted (``download=False``), using the
    ``best[ext=mp4]`` format selector.

    Raises:
        ValueError: if yt-dlp returns no direct ``url`` for the link
            (e.g. presumably when it resolves to a playlist, not one video).
    """
    ydl_opts = {
        'format': 'best[ext=mp4]',
        'quiet': True,
        # For watch URLs that also carry a playlist (&list=...), extract only
        # the single video instead of the whole playlist.
        'noplaylist': True,
    }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        info_dict = ydl.extract_info(youtube_url, download=False)
    stream_url = info_dict.get('url')
    if not stream_url:
        raise ValueError(f"No direct stream URL found for {youtube_url!r}")
    return stream_url
# If a YouTube URL is provided, process the video
if youtube_url:
    analysis_init()  # Initialize the analysis UI
    stream_url = get_youtube_stream_url(youtube_url)
    process_video(stream_url)  # Process the video

# If a video is uploaded or a URL is provided, process the video
if uploaded_video is not None or video_url:
    analysis_init()  # Initialize the analysis UI
    if uploaded_video is not None:
        # Save the upload to a private temporary file instead of the working
        # directory under the client-supplied name. That name is untrusted,
        # can collide between sessions, and the file was never removed. The
        # original extension is kept in case the video backend relies on it.
        suffix = os.path.splitext(uploaded_video.name)[1]
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp.write(uploaded_video.getbuffer())
            video_path = tmp.name
        try:
            process_video(video_path)  # Process the video
        finally:
            os.remove(video_path)  # Clean up the temporary copy
    else:
        # Download the video from the URL
        video_path = download_file(video_url)
        process_video(video_path)  # Process the video