ombhojane's picture
Update app.py
2b2e538 verified
raw
history blame
5.42 kB
import streamlit as st
import os
import glob
import numpy as np
import cv2
from deepface import DeepFace
from scipy.spatial.distance import cosine
import matplotlib.pyplot as plt
from PIL import Image
import tempfile
st.set_page_config(page_title="Celebrity Lookalike Finder", layout="wide")

# Raw CSS rules for the `.main` and `.stTitle` selectors, injected as HTML
# once at startup (hence unsafe_allow_html=True).
_PAGE_CSS = """
<style>
.main {
padding: 2rem;
}
.stTitle {
text-align: center;
}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)

# Page header and one-line usage hint shown above the uploader.
st.title("🌟 Celebrity Lookalike Finder")
st.write("Upload your photo to find your celebrity doppelganger!")
def detect_and_align_face(img_path, margin=30, size=(224, 224)):
    """Crop the most prominent face from an image using OpenCV's Haar cascade.

    No geometric alignment is performed despite the function name: the
    largest detected face box is padded by ``margin`` pixels on each side
    (clipped to the image bounds) and resized to ``size``.

    Args:
        img_path: Path of the image file, read with ``cv2.imread``.
        margin: Padding in pixels added around the detected face box.
        size: ``(width, height)`` passed to ``cv2.resize`` for the crop.

    Returns:
        The resized BGR face crop; the full, unmodified image when no face is
        detected or when detection fails after the image was loaded; ``None``
        when the image cannot be read (or reading it raised).
    """
    # Bound before the try so the except clause can always return it
    # (previously an exception from imread left it unbound).
    img = None
    try:
        img = cv2.imread(img_path)
        if img is None:
            return None
        face_cascade = cv2.CascadeClassifier(
            cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
        )
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        faces = face_cascade.detectMultiScale(gray, 1.1, 4)
        if len(faces) == 0:
            return img
        # Prefer the largest box; the first detection is arbitrary and may be
        # a small false positive.
        x, y, w, h = max(faces, key=lambda box: box[2] * box[3])
        img_h, img_w = img.shape[:2]
        # Clip each edge independently so a face near the border does not get
        # an oversized margin on the opposite side.
        top = max(0, y - margin)
        bottom = min(img_h, y + h + margin)
        left = max(0, x - margin)
        right = min(img_w, x + w + margin)
        face = img[top:bottom, left:right]
        return cv2.resize(face, size)
    except Exception as e:
        st.error(f"Error in face detection: {str(e)}")
        return img
def extract_features(img_path):
    """Return a flat VGG-Face embedding of the image at ``img_path``.

    Calls ``DeepFace.represent`` with the OpenCV detector backend and
    ``enforce_detection=False``, so DeepFace does not raise when it finds no
    face. The result is normalised from either of these shapes:

    * a list of per-face dicts holding an ``"embedding"`` entry (the first
      face is used; other array-like dict values are a fallback);
    * a bare list/array of numbers (the shape older DeepFace releases
      returned).

    Returns:
        A 1-D ``numpy.ndarray``, or ``None`` if the result has an unexpected
        shape (a Streamlit warning is shown) or DeepFace raises (a Streamlit
        error is shown).
    """
    try:
        embedding = DeepFace.represent(
            img_path=img_path,
            model_name="VGG-Face",
            enforce_detection=False,
            detector_backend="opencv",
        )
        # Unwrap a per-face container, but never a flat list of numbers:
        # taking element 0 of that would keep one float instead of the vector.
        if (
            isinstance(embedding, list)
            and embedding
            and isinstance(embedding[0], (dict, list, np.ndarray))
        ):
            embedding = embedding[0]
        if isinstance(embedding, dict):
            if 'embedding' in embedding:
                return np.asarray(embedding['embedding']).flatten()
            # Fallback: first array-like value in the dict.
            for value in embedding.values():
                if isinstance(value, (list, np.ndarray)):
                    return np.asarray(value).flatten()
        elif isinstance(embedding, np.ndarray):
            return embedding.flatten()
        elif isinstance(embedding, list) and embedding:
            return np.asarray(embedding, dtype=float).flatten()
        st.warning(f"Unexpected embedding type: {type(embedding)}")
        return None
    except Exception as e:
        st.error(f"Error in feature extraction: {str(e)}")
        return None
@st.cache_data
def build_celebrity_database(pattern='data/*.*'):
    """Embed every celebrity image matching ``pattern`` and cache the result.

    Wrapped in ``st.cache_data`` so the embedding pass is reused across
    Streamlit reruns for the same ``pattern``. Progress is reported through a
    progress bar and a status line while images are processed.

    Args:
        pattern: ``glob`` pattern selecting the celebrity image files.

    Returns:
        ``(features, paths)``: parallel lists with the embedding of each image
        that ``extract_features`` could process and that image's path.
        Images whose embedding could not be extracted are left out of both.
    """
    # Sorted so database order (and tie-breaking between equal similarity
    # scores) does not depend on the filesystem's listing order.
    celebrity_paths = sorted(glob.glob(pattern))
    celebrity_features = []
    celebrity_paths_list = []
    if not celebrity_paths:
        st.warning(f"No celebrity images found matching '{pattern}'.")
        return celebrity_features, celebrity_paths_list
    total = len(celebrity_paths)
    progress_bar = st.progress(0)
    status_text = st.empty()
    for i, img_path in enumerate(celebrity_paths, start=1):
        status_text.text(f"Processing image {i}/{total}")
        features = extract_features(img_path)
        if features is not None:
            celebrity_features.append(features)
            celebrity_paths_list.append(img_path)
        progress_bar.progress(i / total)
    status_text.text("Database built successfully!")
    return celebrity_features, celebrity_paths_list
def _rank_matches(user_features, celebrity_features, top_n):
    """Return up to ``top_n`` ``(index, similarity)`` pairs, most similar first."""
    scored = []
    for idx, celeb_feature in enumerate(celebrity_features):
        # Embeddings of a different shape cannot be compared; skip them but
        # keep the original index so it still lines up with celebrity_paths.
        if user_features.shape != celeb_feature.shape:
            continue
        similarity = 1 - cosine(user_features, celeb_feature)
        # A zero vector yields NaN, which would otherwise sort to the top.
        if np.isfinite(similarity):
            scored.append((idx, similarity))
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:max(top_n, 0)]


def find_matches(user_features, celebrity_features, celebrity_paths, top_n=5):
    """Show the ``top_n`` celebrities most similar to ``user_features``.

    Similarity is ``1 - cosine distance``. Each match is rendered in its own
    Streamlit column with its rank and similarity as a caption; a warning is
    shown instead when no comparison is possible.

    Args:
        user_features: 1-D embedding of the uploaded photo.
        celebrity_features: Embeddings parallel to ``celebrity_paths``.
        celebrity_paths: Image paths, one per entry of ``celebrity_features``.
        top_n: Maximum number of matches to display.
    """
    matches = _rank_matches(user_features, celebrity_features, top_n)
    if not matches:
        st.warning("No valid comparisons could be made")
        return
    # Only as many columns as there are matches to show.
    cols = st.columns(len(matches))
    for rank, ((idx, similarity), col) in enumerate(zip(matches, cols), start=1):
        with col:
            # Context manager closes the file handle once st.image has it.
            with Image.open(celebrity_paths[idx]) as celeb_img:
                st.image(celeb_img, caption=f"Match {rank}\nSimilarity: {similarity:.2%}")
def main():
    """Render the uploader and show celebrity matches for an uploaded photo.

    The upload is written to a temporary file because ``extract_features``
    takes a path; that file is removed afterwards even if matching fails.
    """
    # Load the celebrity database (cached by build_celebrity_database).
    with st.spinner("Building celebrity database..."):
        celebrity_features, celebrity_paths = build_celebrity_database()

    uploaded_file = st.file_uploader("Choose a photo", type=['jpg', 'jpeg', 'png'])
    if uploaded_file is None:
        return

    # Side-by-side layout: uploaded photo on the left, matches on the right.
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Your Photo")
        st.image(uploaded_file)

    # Keep the upload's own extension so the temp file name matches its bytes.
    suffix = os.path.splitext(uploaded_file.name)[1] or '.jpg'
    # delete=False: the file must outlive this with-block so it can be
    # reopened by path (an open NamedTemporaryFile cannot be on Windows).
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
        tmp_file.write(uploaded_file.getvalue())
        tmp_path = tmp_file.name

    try:
        with st.spinner("Finding your celebrity matches..."):
            user_features = extract_features(tmp_path)
            if user_features is not None:
                with col2:
                    st.subheader("Your Celebrity Matches")
                    find_matches(user_features, celebrity_features, celebrity_paths)
            else:
                st.error("Could not process the uploaded image. Please try another photo.")
    finally:
        # Always clean up, even if matching raised.
        os.unlink(tmp_path)


if __name__ == "__main__":
    main()