Spaces:
Running
Running
import os | |
import sys | |
import tempfile | |
import os.path as osp | |
from PIL import Image | |
from io import BytesIO | |
import numpy as np | |
import pandas as pd | |
import streamlit as st | |
from PIL import ImageOps | |
from matplotlib import pyplot as plt | |
import altair as alt | |
# Directory containing this script; it is put on ``sys.path`` so the sibling modules imported
# below (``registry_utils``, ``app_utils``) can be imported.
root_path = osp.abspath(osp.join(__file__, osp.pardir))
# Streamlit re-executes this script on every rerun, so guard the append to avoid growing
# ``sys.path`` with a duplicate entry per user interaction.
if root_path not in sys.path:
    sys.path.append(root_path)
from registry_utils import import_registered_modules | |
from app_utils import ( | |
extract_frames, | |
is_image, | |
is_video, | |
convert_diameter, | |
overlay_text_on_frame, | |
process_frames, | |
process_video, | |
resize_frame, | |
) | |
# Side-effect import of the modules known to ``registry_utils``.
# NOTE(review): presumably this registers the models/CAM methods that ``process_frames`` and
# ``process_video`` look up — confirm against registry_utils.
import_registered_modules()

# CAM methods; ``main`` passes the last entry as ``cam_method``.
CAM_METHODS = ["CAM"]
# Classification backbones (matches the "Classification model" selectbox options).
TV_MODELS = [
    "ResNet18",
    "ResNet50",
]
# Super-resolution / upscaling options; not referenced in the visible part of this file.
SR_METHODS = [
    "GFPGAN",
    "CodeFormer",
    "RealESRGAN",
    "SRResNet",
    "HAT",
]
UPSCALE = [
    2,
    4,
]
UPSCALE_METHODS = [
    "BILINEAR",
    "BICUBIC",
]
# Pupil labels offered (after "both") in the "Pupil Selection" selectbox.
LABEL_MAP = [
    "left_pupil",
    "right_pupil",
]
# Line colours for the charts: one per pupil diameter chart (cycled), last entry for the EAR chart.
_CHART_COLORS = ["#1f77b4", "#ff7f0e", "#636363"]  # Blue, Orange, Gray
# Eye-aspect-ratio reference lines drawn on the EAR chart.
_EAR_BLINK_THRESHOLD = 0.22  # at or below: "Definite Blinks"
_EAR_OPEN_THRESHOLD = 0.25  # at or above: "No Blinks"


def _load_image(uploaded_file):
    """Decode an uploaded image as RGB, apply its EXIF orientation and fit it into 640x480."""
    input_img = Image.open(BytesIO(uploaded_file.read())).convert("RGB")
    # NOTE: photos taken with a phone often carry an EXIF orientation tag, so a picture taken
    # with the phone tilted is displayed rotated. ``exif_transpose`` applies that tag so the
    # image is upright.
    input_img = ImageOps.exif_transpose(input_img)
    return resize_frame(input_img, max_width=640, max_height=480)


def _load_video(uploaded_file, file_extension):
    """Read an uploaded video and extract its frames.

    ``extract_frames`` takes a path, so the bytes are written to a temporary file. The file is
    closed (flushing the write) before it is read, and deleted right after extraction. Streamlit
    reruns this script on every interaction, so a kept file would leak once per rerun.

    Returns:
        ``(video_bytes, video_frames)``: the raw upload and whatever ``extract_frames`` returned.
    """
    video_bytes = uploaded_file.read()
    with tempfile.NamedTemporaryFile(suffix=f".{file_extension}", delete=False) as tfile:
        tfile.write(video_bytes)
        video_path = tfile.name
    try:
        video_frames = extract_frames(video_path)
    finally:
        os.remove(video_path)
    return video_bytes, video_frames


def _format_diameter(diameter):
    """Format a diameter with two decimals; fall back to ``str`` for values such as ``None``
    or strings that do not support the ``.2f`` format."""
    try:
        return f"{diameter:.2f}"
    except (TypeError, ValueError):
        return str(diameter)


def _show_image_predictions(container, input_frames, output_frames, predicted_diameters):
    """Show, in one sub-column per eye: the last input frame, the last output frame and the
    first predicted diameter rendered as text on a black frame of the output's size."""
    eye_types = list(input_frames)
    if not eye_types:
        # st.columns(0) raises, so report instead of crashing.
        container.warning("No predictions were returned for this image.")
        return
    eye_cols = container.columns(len(eye_types))
    for eye_col, eye_type in zip(eye_cols, eye_types):
        eye_col.image(input_frames[eye_type][-1], use_column_width=True)
    for eye_col, eye_type in zip(eye_cols, output_frames):
        height, width, channels = output_frames[eye_type][0].shape
        eye_col.image(output_frames[eye_type][-1], use_column_width=True)
        # Render the number as an image so it has the same width as the pictures above it.
        canvas = np.zeros((height, width, channels), dtype=np.uint8)
        canvas = overlay_text_on_frame(canvas, _format_diameter(predicted_diameters[eye_type][0]))
        eye_col.image(canvas, use_column_width=True)


def _plot_diameter_charts(predicted_diameters):
    """Draw one Altair line+point chart per entry of ``predicted_diameters``, side by side.

    Each value is passed through ``convert_diameter``. ``None`` results are excluded when the
    y-axis range is computed.
    """
    if not predicted_diameters:
        return  # st.columns(0) would raise
    chart_cols = st.columns(len(predicted_diameters))
    for i, (chart_col, (category, raw_values)) in enumerate(zip(chart_cols, predicted_diameters.items())):
        color = _CHART_COLORS[i % len(_CHART_COLORS)]
        values = [convert_diameter(value) for value in raw_values]
        df = pd.DataFrame(values, columns=[category])
        df["Frame"] = range(1, len(values) + 1)  # 1-based frame numbers

        y_kwargs = {"title": "Diameter"}
        numeric_values = [value for value in values if value is not None]
        if numeric_values:
            # Pin the y-axis to the observed min/max; skipped when nothing is numeric.
            y_kwargs["scale"] = alt.Scale(domain=[min(numeric_values), max(numeric_values)])

        line_chart = (
            alt.Chart(df)
            .mark_line(color=color)
            .encode(
                x=alt.X("Frame:Q", title="Frame Number"),
                y=alt.Y(f"{category}:Q", **y_kwargs),
                tooltip=["Frame", alt.Tooltip(f"{category}:Q", title="Diameter")],
            )
        )
        points_chart = line_chart.mark_point(color=color, filled=True)
        final_chart = (
            line_chart.properties(title=f"{category} - Predicted Diameters") + points_chart
        ).interactive()
        with chart_col:
            st.altair_chart(final_chart.configure_axis(grid=True), use_container_width=True)


def _hline(y, color):
    """Return an Altair horizontal rule at ``y``."""
    return alt.Chart(pd.DataFrame({"y": [y]})).mark_rule(color=color).encode(y="y:Q")


def _hline_label(y, label, color, dx, dy):
    """Return an Altair text mark showing ``label`` at height ``y``, offset by ``dx``/``dy`` pixels."""
    return (
        alt.Chart(pd.DataFrame({"y": [y], "label": [label]}))
        .mark_text(align="left", dx=dx, dy=dy, color=color, size=16)
        .encode(y="y:Q", text="label:N")
    )


def _plot_ear_chart(eyes_ratios):
    """Plot the per-frame eye aspect ratios with the blink / no-blink reference lines."""
    color = _CHART_COLORS[-1]
    df = pd.DataFrame(eyes_ratios, columns=["EAR"])
    df["Frame"] = range(1, len(eyes_ratios) + 1)  # 1-based frame numbers
    line_chart = (
        alt.Chart(df)
        .mark_line(color=color)
        .encode(
            x=alt.X("Frame:Q", title="Frame Number"),
            y=alt.Y("EAR:Q", title="Eyes Aspect Ratio"),
            tooltip=["Frame", "EAR"],
        )
    )
    points_chart = line_chart.mark_point(color=color, filled=True)
    gray_area_y = (_EAR_BLINK_THRESHOLD + _EAR_OPEN_THRESHOLD) / 2  # label between the two lines
    final_chart = (
        line_chart.properties(title="Eyes Aspect Ratios (EARs)")
        + points_chart
        + _hline(_EAR_BLINK_THRESHOLD, "red")
        + _hline(_EAR_OPEN_THRESHOLD, "green")
        + _hline_label(_EAR_BLINK_THRESHOLD, f"Definite Blinks (<={_EAR_BLINK_THRESHOLD})", "red", dx=100, dy=9)
        + _hline_label(_EAR_OPEN_THRESHOLD, f"No Blinks (>={_EAR_OPEN_THRESHOLD})", "green", dx=-150, dy=-9)
        + _hline_label(gray_area_y, "Gray Area", "gray", dx=0, dy=0)
    ).interactive()
    st.altair_chart(final_chart.configure_axis(grid=True), use_container_width=True)


def main():
    """Render the EyeDentify Streamlit playground.

    The sidebar holds the upload widget and the prediction settings. The left column previews
    the uploaded image or video. When "Predict Diameter & Compute CAM" is pressed, the upload is
    run through ``process_frames`` (image) or ``process_video`` (video). The predicted diameters
    are then plotted, along with the eye aspect ratios when any are returned.
    """
    st.set_page_config(page_title="Pupil Diameter Estimator", layout="wide")
    st.title("EyeDentify Playground")
    cols = st.columns((1, 1))
    cols[0].header("Input")
    cols[-1].header("Prediction")

    st.sidebar.title("Upload Face or Eye")
    uploaded_file = st.sidebar.file_uploader(
        "Upload Image or Video", type=["png", "jpeg", "jpg", "mp4", "avi", "mov", "mkv", "webm"]
    )
    file_extension = None
    input_img = None
    video_frames = None
    if uploaded_file is not None:
        # Lower-cased so uploads such as "clip.MP4" match the lower-case type list above.
        file_extension = uploaded_file.name.rsplit(".", 1)[-1].lower()
        if is_image(file_extension):
            input_img = _load_image(uploaded_file)
            cols[0].image(input_img, use_column_width=True)
            st.session_state.total_frames = 1
        elif is_video(file_extension):
            video_bytes, video_frames = _load_video(uploaded_file, file_extension)
            cols[0].video(video_bytes)
            st.session_state.total_frames = len(video_frames)
            st.session_state.current_frame = 0
            st.session_state.frame_placeholder = cols[0].empty()
            txt = f"<p style='font-size:20px;'> Number of Frames Processed: <strong>{st.session_state.current_frame} / {st.session_state.total_frames}</strong> </p>"
            st.session_state.frame_placeholder.markdown(txt, unsafe_allow_html=True)

    st.sidebar.title("Setup")
    pupil_selection = st.sidebar.selectbox(
        "Pupil Selection", ["both"] + LABEL_MAP, help="Select left or right pupil OR both for diameter estimation"
    )
    tv_model = st.sidebar.selectbox("Classification model", TV_MODELS, help="Supported Models")
    blink_detection = st.sidebar.checkbox("Detect Blinks")
    # Raise the z-index of the Vega tooltip element so chart tooltips render above other content.
    st.markdown("<style>#vg-tooltip-element{z-index: 1000051}</style>", unsafe_allow_html=True)

    if not st.sidebar.button("Predict Diameter & Compute CAM"):
        return
    if uploaded_file is None:
        st.sidebar.error("Please upload an image or video")
        return

    with st.spinner("Analyzing..."):
        if is_image(file_extension):
            input_frames, output_frames, predicted_diameters, _, eyes_ratios = process_frames(
                cols,
                [input_img],
                tv_model,
                pupil_selection,
                cam_method=CAM_METHODS[-1],
                blink_detection=blink_detection,
            )
            _show_image_predictions(cols[1], input_frames, output_frames, predicted_diameters)
        elif is_video(file_extension):
            # NOTE(review): this path is the same for every session of the app; presumably
            # process_video writes its output here, so concurrent runs may overwrite each other.
            output_video_path = osp.join(root_path, "tmp.webm")
            _, _, predicted_diameters, _, eyes_ratios = process_video(
                cols,
                video_frames,
                tv_model,
                pupil_selection,
                output_video_path,
                cam_method=CAM_METHODS[-1],
                blink_detection=blink_detection,
            )
        else:
            # The uploader's type filter should prevent this; previously it ended in a NameError.
            st.sidebar.error("Unsupported file type")
            return

    _plot_diameter_charts(predicted_diameters)
    if eyes_ratios is not None and len(eyes_ratios) > 0:
        _plot_ear_chart(eyes_ratios)


if __name__ == "__main__":
    main()