# Spaces: Sleeping
import streamlit as st | |
import os | |
import glob | |
import numpy as np | |
import cv2 | |
from deepface import DeepFace | |
from scipy.spatial.distance import cosine | |
import matplotlib.pyplot as plt | |
from PIL import Image | |
import tempfile | |
# Page configuration; kept ahead of every other Streamlit command.
st.set_page_config(page_title="Celebrity Lookalike Finder", layout="wide")

# Custom CSS injected as raw HTML: padding on `.main` and centring on `.stTitle`.
_PAGE_CSS = """
<style>
.main {
padding: 2rem;
}
.stTitle {
text-align: center;
}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)

# Header text above the uploader.
# NOTE(review): the leading "π" looks like a mis-encoded emoji — confirm the intended character.
st.title("π Celebrity Lookalike Finder")
st.write("Upload your photo to find your celebrity doppelganger!")
def detect_and_align_face(img_path):
    """Crop the largest detected face from an image, padded by a margin.

    Uses OpenCV's frontal-face Haar cascade. Despite the name, no rotation or
    landmark-based alignment is performed. The chosen bounding box is grown
    by 30 px on each side (clamped to the image) and resized to 224x224.

    Args:
        img_path: Path of the image to read with ``cv2.imread``.

    Returns:
        * ``None`` if the file cannot be read as an image.
        * The full, unresized BGR image if no face is detected, or if an error
          occurs after the image was loaded (an error is shown in the UI).
        * Otherwise, a 224x224 BGR crop around the largest detected face.
    """
    # Bound before the try so the except branch never references an unbound
    # local when the failure happens inside cv2.imread itself.
    img = None
    try:
        img = cv2.imread(img_path)
        if img is None:
            return None
        face_cascade = cv2.CascadeClassifier(
            cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
        )
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # Positional arguments: scaleFactor=1.1, minNeighbors=4.
        faces = face_cascade.detectMultiScale(gray, 1.1, 4)
        if len(faces) == 0:
            return img
        # Prefer the largest box. It is the most likely to be the subject
        # rather than a small false positive; detection order is not
        # documented to reflect size.
        x, y, w, h = max(faces, key=lambda box: int(box[2]) * int(box[3]))
        margin = 30
        y = max(0, y - margin)
        h = min(img.shape[0] - y, h + 2 * margin)
        x = max(0, x - margin)
        w = min(img.shape[1] - x, w + 2 * margin)
        face = img[y:y + h, x:x + w]
        return cv2.resize(face, (224, 224))
    except Exception as e:
        st.error(f"Error in face detection: {e}")
        return img
def extract_features(img_path):
    """Compute a VGG-Face embedding for the image at ``img_path`` via DeepFace.

    Face detection is not enforced (``enforce_detection=False``), so DeepFace
    is asked for an embedding even when the OpenCV detector finds no face.
    Several return shapes from ``DeepFace.represent`` are accepted:

      * a list of per-face dicts with an ``"embedding"`` key (first entry used),
      * a flat list of numbers (treated as the embedding itself),
      * a dict or ``np.ndarray``.

    Args:
        img_path: Path to the image file.

    Returns:
        A 1-D ``np.ndarray`` embedding, or ``None`` on failure. A Streamlit
        warning or error is shown explaining why.
    """
    try:
        embedding = DeepFace.represent(
            img_path=img_path,
            model_name="VGG-Face",
            enforce_detection=False,
            detector_backend="opencv",
        )
        if isinstance(embedding, list):
            if not embedding:
                st.warning("No embedding was returned for this image.")
                return None
            # Some DeepFace releases return the embedding itself as a flat
            # list of floats rather than a list of per-face dicts. Taking
            # element 0 of such a list would keep only a single float.
            if all(isinstance(v, (int, float, np.number)) for v in embedding):
                return np.asarray(embedding, dtype=float).flatten()
            embedding = embedding[0]
        if isinstance(embedding, dict):
            if 'embedding' in embedding:
                return np.array(embedding['embedding']).flatten()
            # Fallback: the first list/array-valued entry in the dict.
            for value in embedding.values():
                if isinstance(value, (list, np.ndarray)):
                    return np.array(value).flatten()
        if isinstance(embedding, np.ndarray):
            return embedding.flatten()
        st.warning(f"Unexpected embedding type: {type(embedding)}")
        return None
    except Exception as e:
        st.error(f"Error in feature extraction: {str(e)}")
        return None
def build_celebrity_database():
    """Build the celebrity embedding database, cached per browser session.

    Every file matching ``data/*.*`` (relative to the working directory) is
    embedded with ``extract_features``; files whose extraction fails are
    skipped. Streamlit re-executes the whole script on every interaction.
    A non-empty result is therefore stored in ``st.session_state`` and reused
    on later reruns of the same session, instead of re-embedding the whole
    directory each time.

    Returns:
        ``(features, paths)``: parallel lists where ``features[i]`` is the
        embedding of the image at ``paths[i]``.
    """
    cache_key = "_celebrity_database"
    if cache_key in st.session_state:
        return st.session_state[cache_key]

    # Sorted so processing order does not depend on filesystem listing order.
    celebrity_paths = sorted(glob.glob('data/*.*'))
    if not celebrity_paths:
        st.warning("No celebrity images found in 'data/'.")
        return [], []

    celebrity_features = []
    celebrity_paths_list = []
    progress_bar = st.progress(0)
    status_text = st.empty()
    total = len(celebrity_paths)
    for i, img_path in enumerate(celebrity_paths, start=1):
        status_text.text(f"Processing image {i}/{total}")
        features = extract_features(img_path)
        if features is not None:
            celebrity_features.append(features)
            celebrity_paths_list.append(img_path)
        progress_bar.progress(i / total)
    status_text.text("Database built successfully!")

    result = (celebrity_features, celebrity_paths_list)
    # Only cache a usable database so a failed build is retried on the next rerun.
    if celebrity_features:
        st.session_state[cache_key] = result
    return result
def find_matches(user_features, celebrity_features, celebrity_paths, top_n=5):
    """Display the ``top_n`` celebrities most similar to the user's embedding.

    Similarity is ``1 - scipy cosine distance``, shown as a percentage.
    Entries whose embedding shape differs from the user's, or whose
    similarity is not finite, are skipped.

    Args:
        user_features: The user's embedding (``np.ndarray``).
        celebrity_features: Embeddings, parallel to ``celebrity_paths``.
        celebrity_paths: Image paths; ``celebrity_paths[i]`` belongs to
            ``celebrity_features[i]``.
        top_n: Maximum number of matches to show; nothing is drawn if < 1.

    Returns:
        None. Results are drawn into ``top_n`` Streamlit columns; a warning is
        shown if no comparison could be made.
    """
    if top_n < 1:
        # A slice of [-0:] would select everything, and st.columns(0) is invalid.
        return

    # Keep each score paired with its database index. Skipping an entry must
    # not shift later scores onto the wrong image path.
    scored = []
    for idx, celeb_feature in enumerate(celebrity_features):
        if user_features.shape != celeb_feature.shape:
            continue
        similarity = 1 - cosine(user_features, celeb_feature)
        # cosine() can yield nan (e.g. an all-zero vector). A nan would
        # otherwise be sorted to the end by argsort and shown as the best match.
        if not np.isfinite(similarity):
            continue
        scored.append((idx, similarity))

    if not scored:
        st.warning("No valid comparisons could be made")
        return

    ranked = sorted(scored, key=lambda item: item[1], reverse=True)[:top_n]

    # Display results in columns
    cols = st.columns(top_n)
    for rank, ((idx, similarity), col) in enumerate(zip(ranked, cols), start=1):
        with col:
            # Close the file handle once the image has been rendered.
            with Image.open(celebrity_paths[idx]) as celeb_img:
                st.image(celeb_img, caption=f"Match {rank}\nSimilarity: {similarity:.2%}")
def main():
    """Streamlit entry point: load the database, accept an upload, show matches.

    The upload is written to a temporary file because ``extract_features``
    takes a path. That file is always removed, even if matching raises.
    """
    # Load celebrity database
    with st.spinner("Building celebrity database..."):
        celebrity_features, celebrity_paths = build_celebrity_database()

    # File uploader
    uploaded_file = st.file_uploader("Choose a photo", type=['jpg', 'jpeg', 'png'])
    if uploaded_file is None:
        return

    # Create columns for side-by-side display
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Your Photo")
        st.image(uploaded_file)

    # delete=False: the file is closed when this `with` exits but is read
    # afterwards by path, so it must outlive the handle.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file:
        tmp_file.write(uploaded_file.getvalue())
        tmp_path = tmp_file.name

    try:
        # Extract features and find matches
        with st.spinner("Finding your celebrity matches..."):
            user_features = extract_features(tmp_path)
            if user_features is not None:
                with col2:
                    st.subheader("Your Celebrity Matches")
                    find_matches(user_features, celebrity_features, celebrity_paths)
            else:
                st.error("Could not process the uploaded image. Please try another photo.")
    finally:
        # Clean up even if feature extraction or rendering raised.
        os.unlink(tmp_path)


if __name__ == "__main__":
    main()