# NOTE: header residue from the hosting site's file viewer, kept as a comment:
# osamaifti · "Update app.py" · commit a5b31f5 · 2.95 kB
import gradio as gr
import torch
import torchvision
import transformers
import numpy as np
from scipy.spatial.distance import cosine
import cv2
import os
# Largest cosine distance still treated as a match: recognize_user returns
# "User not recognized." when the smallest distance it finds exceeds this value.
RECOGNITION_THRESHOLD = 0.3
# Load the embedding model
# NOTE(review): torch.load unpickles the file, and unpickling can execute
# arbitrary code -- only ship a 'full_mode2.pth' that comes from a trusted source.
# Presumably the file holds a whole pickled model object rather than a
# state_dict, since .eval() is called directly on the result -- confirm.
# Newer PyTorch releases may need weights_only=False for that -- TODO confirm.
embedding_model = torch.load('full_mode2.pth', map_location=torch.device('cpu'))
embedding_model.eval() # Set the model to evaluation mode
# Database to store embeddings and user IDs
# Plain in-memory dict (user ID -> embedding from generate_embedding); nothing
# in this file writes it to disk, so registrations last only for this process.
user_embeddings = {}
# Preprocess the image
def preprocess_image(image, size=(375, 375)):
    """Turn an image array into a batched, channel-first float tensor.

    Args:
        image: Image as a 3-D array in height-width-channel order (the
            layout ``np.transpose(..., (2, 0, 1))`` below relies on).
        size: Target size handed to ``cv2.resize`` as its ``dsize``
            argument, i.e. ``(width, height)``. Defaults to the 375x375
            size this file has always used.

    Returns:
        A ``torch.float32`` tensor with a leading batch dimension of 1,
        channels first, whose values are the resized pixels divided by 255.

    Raises:
        ValueError: If ``image`` is ``None`` (e.g. nothing was uploaded) or
            is not a 3-D array.
    """
    # Fail early with a readable message; callers surface str(exc) to the
    # user, and the errors cv2/NumPy raise for these inputs are opaque.
    if image is None:
        raise ValueError("No image provided.")
    if getattr(image, "ndim", None) != 3:
        raise ValueError(
            "Expected a 3-D image array in height-width-channel order."
        )

    resized = cv2.resize(image, size)  # Resize image
    scaled = resized / 255.0  # Normalize pixel values
    channel_first = np.transpose(scaled, (2, 0, 1))  # HWC -> CHW
    # unsqueeze(0) adds the batch dimension.
    return torch.tensor(channel_first, dtype=torch.float32).unsqueeze(0)
# Generate embedding
def generate_embedding(image):
    """Run the embedding model on one image and return its embedding.

    The image goes through ``preprocess_image`` first; the model is then
    called with gradient tracking disabled, and the first (only) item of
    the output batch is returned as a NumPy array.
    """
    batch = preprocess_image(image)
    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        model_output = embedding_model(batch)
        as_array = model_output.numpy()
        return as_array[0]
# Register new user
def register_user(image, user_id):
    """Compute an embedding for ``image`` and store it under ``user_id``.

    Args:
        image: Image to embed; passed straight to ``generate_embedding``.
        user_id: Key under which the embedding is stored in
            ``user_embeddings``. An existing entry with the same key is
            overwritten.

    Returns:
        A status message: a success message, or a message starting with
        "Error during registration: " when the ID is blank or when
        embedding/storing fails. Exceptions are reported in the message
        rather than raised.
    """
    # A blank ID would otherwise be stored and later reported as
    # "Recognized User: " with no name; reject it before doing any work.
    if user_id is None or not str(user_id).strip():
        return "Error during registration: User ID must not be empty."
    try:
        embedding = generate_embedding(image)
        user_embeddings[user_id] = embedding
        return f"User {user_id} registered successfully."
    except Exception as e:
        # UI boundary: turn any failure into a message for the output box.
        return f"Error during registration: {str(e)}"
# Recognize user
def recognize_user(image):
    """Match ``image`` against every registered embedding.

    Embeds the image, computes the cosine distance to each entry in
    ``user_embeddings`` and keeps the closest one. Returns
    "Recognized User: <id>" unless the smallest distance is above
    ``RECOGNITION_THRESHOLD`` (or nothing is registered), in which case
    "User not recognized." is returned. Any exception is reported as an
    "Error during recognition: ..." message instead of being raised.
    """
    try:
        query = generate_embedding(image)
        best_distance, best_user = float('inf'), "Unknown"
        for candidate_id, stored in user_embeddings.items():
            candidate_distance = cosine(query, stored)
            print(f"Distance for {candidate_id}: {candidate_distance}") # Debug: Print distances
            if candidate_distance < best_distance:
                best_distance, best_user = candidate_distance, candidate_id
        print(f"Min distance: {best_distance}") # Debug: Print minimum distance
        # Guard clause: too far from every registered embedding.
        if best_distance > RECOGNITION_THRESHOLD:
            return "User not recognized."
        return f"Recognized User: {best_user}"
    except Exception as err:
        return f"Error during recognition: {str(err)}"
def main():
    """Build the two-tab Gradio interface and launch it.

    The "Register" tab wires an image plus a user-ID textbox to
    ``register_user``; the "Recognize" tab wires an image to
    ``recognize_user``. Each handler's returned string is shown in that
    tab's output textbox.
    """
    # NOTE(review): the source this was recovered from had lost its
    # indentation; the Row/Tab nesting below is the conventional reading --
    # confirm which components were meant to sit inside each Row.
    with gr.Blocks() as demo:
        gr.Markdown("Facial Recognition System")
        with gr.Tab("Register"):
            with gr.Row():
                img_register = gr.Image()
                user_id = gr.Textbox(label="User ID")
                register_button = gr.Button("Register")
            register_output = gr.Textbox()
            # Clicking passes (image, user ID) to register_user.
            register_button.click(register_user, inputs=[img_register, user_id], outputs=register_output)
        with gr.Tab("Recognize"):
            with gr.Row():
                img_recognize = gr.Image()
                recognize_button = gr.Button("Recognize")
            recognize_output = gr.Textbox()
            # Clicking passes the image to recognize_user.
            recognize_button.click(recognize_user, inputs=[img_recognize], outputs=recognize_output)
    # share=True asks Gradio for a public share link -- presumably intended,
    # but it exposes registration/recognition beyond localhost; confirm.
    demo.launch(share=True)
# Start the UI only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()