slr-easz / app.py
Niharmahesh's picture
Update app.py
04eb0d0 verified
raw
history blame
6.44 kB
import streamlit as st
import cv2
import numpy as np
import mediapipe as mp
import joblib
import pandas as pd
from numpy.linalg import norm
import matplotlib.pyplot as plt
import os
import base64
# Page-level configuration: request Streamlit's wide layout so the columns
# and plots below can use the full browser width.
st.set_page_config(
    layout="wide",
)
# Function to load the Random Forest model
@st.cache_resource
def _load_model_from_disk(model_path):
    """Unpickle and return the classifier stored at *model_path*.

    Cached with ``st.cache_resource`` so the file is read once and the object
    is reused. Errors are deliberately NOT caught here: Streamlit does not
    cache a call that raises, so a failed load is retried on the next run
    instead of a ``None`` result being cached.
    """
    # SECURITY: joblib.load unpickles the file, which can execute arbitrary
    # code -- only ever point this at a trusted, locally shipped model file.
    return joblib.load(model_path)


def load_model(model_path='best_random_forest_model.pkl'):
    """Return the trained Random Forest classifier, or None if loading fails.

    Args:
        model_path: path of the joblib/pickle file to load. Defaults to the
            file the app ships with.

    Returns:
        The loaded model object, or ``None`` after showing the error with
        ``st.error`` when the file cannot be loaded.
    """
    try:
        return _load_model_from_disk(model_path)
    except Exception as e:
        # UI boundary: report the problem to the user; the caller decides
        # whether to stop the app.
        st.error(f"Error loading model: {e}")
        return None
# Obtain the classifier through the loader defined above. load_model()
# reports its own error and yields None on failure.
model = load_model()
if model is None:
    # Every tab below depends on the classifier, so end this script run here.
    st.stop()
# Function to normalize landmarks
def normalize_landmarks(landmarks):
    """Center points on their mean and scale each axis to unit standard deviation.

    Args:
        landmarks: 2-D array with one row per point (process_and_predict
            passes rows of ``[x, y]``).

    Returns:
        Array of the same shape. Statistics are taken per column
        (``axis=0``). If a column has zero spread, the division is 0/0 and
        the resulting NaN is mapped to 0 by ``np.nan_to_num``. Any +/-inf
        becomes a large finite value.
    """
    center = np.mean(landmarks, axis=0)
    landmarks_centered = landmarks - center
    std_dev = np.std(landmarks_centered, axis=0)
    # A zero std is an expected degenerate case that nan_to_num already
    # handles. Suppress the divide/invalid RuntimeWarning it would otherwise
    # emit; the numeric result is unchanged.
    with np.errstate(divide='ignore', invalid='ignore'):
        landmarks_normalized = landmarks_centered / std_dev
    return np.nan_to_num(landmarks_normalized)
# Function to calculate angles between landmarks
def calculate_angles(landmarks):
    """Return direction angles for every pair of the first 21 landmarks.

    For each pair ``(a, b)`` with ``0 <= a < b <= 20``, visited with ``a`` as
    the outer index, the vector ``landmarks[b] - landmarks[a]`` is formed.
    Two values are appended for it: its angle to the x axis, then its angle
    to the y axis, both in radians via ``arccos`` of the clipped direction
    cosine. The result is a list of 420 values (210 pairs x 2).

    Note: if two landmarks coincide, the vector has zero length and both of
    its angles come out as NaN.
    """
    index_pairs = [(a, b) for a in range(20) for b in range(a + 1, 21)]
    result = []
    for a, b in index_pairs:
        delta = landmarks[b] - landmarks[a]
        length = norm(delta)
        cos_x = np.clip(delta[0] / length, -1.0, 1.0)
        cos_y = np.clip(delta[1] / length, -1.0, 1.0)
        result += [np.arccos(cos_x), np.arccos(cos_y)]
    return result
# Function to process image and predict alphabet
def process_and_predict(image):
    """Find a single hand in *image* and score it against every class of ``model``.

    Args:
        image: OpenCV image in BGR channel order; it is converted to RGB
            before being handed to MediaPipe.

    Returns:
        ``(probabilities, landmarks)``:
        - ``probabilities`` is the first row of ``model.predict_proba``,
          indexed like ``model.classes_``.
        - ``landmarks`` is the un-normalized array of MediaPipe ``[x, y]``
          values, one row per landmark.
        Returns ``(None, None)`` when no hand is detected.
    """
    hands_api = mp.solutions.hands
    detector = hands_api.Hands(static_image_mode=True, max_num_hands=1, min_detection_confidence=0.5)
    with detector as hands:
        detection = hands.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        if not detection.multi_hand_landmarks:
            return None, None
        # max_num_hands=1, so only the first detected hand matters.
        hand = detection.multi_hand_landmarks[0]
        raw_points = np.array([[point.x, point.y] for point in hand.landmark])
        features = calculate_angles(normalize_landmarks(raw_points))
        # NOTE(review): these column names presumably mirror the feature
        # names used when the model was trained -- confirm before changing.
        feature_frame = pd.DataFrame(
            [features], columns=[f'angle_{n}' for n in range(len(features))]
        )
        scores = model.predict_proba(feature_frame)[0]
        return scores, raw_points
# Function to plot hand landmarks
def plot_hand_landmarks(landmarks, title):
    """Draw hand landmarks and the MediaPipe hand skeleton on a new figure.

    Args:
        landmarks: array where ``landmarks[k, 0]`` / ``landmarks[k, 1]`` are
            the x / y values of landmark k (as returned by
            process_and_predict).
        title: text shown above the plot.

    Returns:
        A ``matplotlib.figure.Figure`` ready to be passed to ``st.pyplot``.
    """
    # Build the Figure directly instead of via plt.subplots. Figures created
    # through pyplot stay registered in pyplot's global figure manager until
    # explicitly closed, so every rerun of the app would leak one.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 10))
    ax = fig.subplots()
    ax.scatter(landmarks[:, 0], landmarks[:, 1], c='blue', s=50)
    # Each connection is a (start, end) pair of landmark indices.
    for start_idx, end_idx in mp.solutions.hands.HAND_CONNECTIONS:
        ax.plot([landmarks[start_idx, 0], landmarks[end_idx, 0]],
                [landmarks[start_idx, 1], landmarks[end_idx, 1]], 'r-', linewidth=2)
    # Presumably the landmark y values grow downward (image-row order), so
    # flip the axis to draw the hand upright.
    ax.invert_yaxis()
    ax.set_title(title, fontsize=16)
    ax.axis('off')
    return fig
# Function to create a download link for the README file
def get_binary_file_downloader_html(bin_file, file_label='File'):
    """Build an ``<a>`` tag that downloads *bin_file* from an embedded data URI.

    The file is read in binary mode and inlined as base64. The suggested
    download name is the basename of *bin_file*, and the link text is
    ``Download <file_label>``.

    Note: neither *bin_file* nor *file_label* is HTML-escaped, and callers
    render the result as raw HTML, so only pass trusted values.
    """
    with open(bin_file, 'rb') as source:
        payload = source.read()
    encoded = base64.b64encode(payload).decode()
    download_name = os.path.basename(bin_file)
    return f'<a href="data:application/octet-stream;base64,{encoded}" download="{download_name}">Download {file_label}</a>'
# Streamlit app
st.title("ASL Recognition App")
# Add README button
readme_col1, readme_col2 = st.columns([1, 3])
with readme_col1:
    st.markdown("## How it works")
with readme_col2:
    # A missing or unreadable readme.md must not take down the whole app:
    # an uncaught error here would stop the script before the tabs render.
    try:
        readme_link = get_binary_file_downloader_html('readme.md', 'README')
    except OSError as e:
        st.warning(f"README not available: {e}")
    else:
        st.markdown(readme_link, unsafe_allow_html=True)
# Create tabs for different functionalities
tab1, tab2 = st.tabs(["Predict ASL Sign", "View Hand Landmarks"])
with tab1:
    st.header("Predict ASL Sign")
    uploaded_file = st.file_uploader("Upload an image of an ASL sign", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        try:
            # getvalue() returns the whole upload regardless of the current
            # read position, unlike read().
            file_bytes = np.frombuffer(uploaded_file.getvalue(), np.uint8)
            # IMREAD_COLOR (== 1) decodes to a 3-channel BGR array; None on failure.
            image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
            if image is not None:
                # OpenCV arrays are BGR. Without channels="BGR", st.image
                # assumes RGB and shows red and blue swapped.
                st.image(image, caption="Uploaded Image", use_column_width=True, channels="BGR")
                probabilities, landmarks = process_and_predict(image)
                if probabilities is not None and landmarks is not None:
                    st.subheader("Top 5 Predictions:")
                    # Indices of the five highest probabilities, best first.
                    top_indices = np.argsort(probabilities)[::-1][:5]
                    for i in top_indices:
                        st.write(f"{model.classes_[i]}: {probabilities[i]:.2f}")
                    fig = plot_hand_landmarks(landmarks, "Detected Hand Landmarks")
                    st.pyplot(fig)
                else:
                    st.write("No hand detected in the image.")
            else:
                st.error("Failed to load the image. The file might be corrupted.")
        except Exception as e:
            # UI boundary: report any failure instead of crashing the page.
            st.error(f"An error occurred while processing the image: {str(e)}")
with tab2:
    st.header("View Hand Landmarks")
    all_alphabets = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    # NOTE(review): presumably letters with no usable reference image in
    # 'asl test set' -- confirm.
    excluded_alphabets = 'DMNPTUVXZ'
    # Filter in alphabetical order. A set difference would iterate in an
    # arbitrary order, which also varies between processes because str
    # hashing is randomized, and would shuffle the options below.
    available_alphabets = ''.join(
        letter for letter in all_alphabets if letter not in excluded_alphabets
    )
    selected_alphabets = st.multiselect("Select alphabets to view landmarks:", list(available_alphabets))
    if selected_alphabets:
        # At most three plots per row; wrap around the available columns.
        cols = st.columns(min(3, len(selected_alphabets)))
        for idx, alphabet in enumerate(selected_alphabets):
            with cols[idx % len(cols)]:
                image_path = os.path.join('asl test set', f'{alphabet.lower()}.jpeg')
                st.write(f"Attempting to load: {image_path}")
                if os.path.exists(image_path):
                    try:
                        image = cv2.imread(image_path)
                        if image is not None:
                            # Only the landmarks are needed here; the prediction is ignored.
                            landmarks = process_and_predict(image)[1]
                            if landmarks is not None:
                                fig = plot_hand_landmarks(landmarks, f"Hand Landmarks for {alphabet}")
                                st.pyplot(fig)
                            else:
                                st.error(f"No hand detected for {alphabet}")
                        else:
                            st.error(f"Failed to load image for {alphabet}. The file might be corrupted.")
                    except Exception as e:
                        st.error(f"An error occurred while processing image for {alphabet}: {str(e)}")
                else:
                    st.error(f"Image not found for {alphabet}")