# Captured from the Hugging Face Space file viewer (not part of the program):
# author: osamaifti · commit eed8f27 "Update app.py" · file size 2.75 kB
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
from scipy.spatial.distance import cosine
# Constants
# Maximum cosine *distance* (1 - cosine similarity, as returned by
# scipy.spatial.distance.cosine) at which the closest registered user is
# still accepted as a match. Larger values accept looser matches.
RECOGNITION_THRESHOLD = 0.8

# Load the model.
# The checkpoint is a fully pickled nn.Module (it is called directly and
# .eval()'d below), not a state_dict. PyTorch 2.6 changed torch.load's default
# to weights_only=True, which refuses arbitrary pickled objects, so the flag is
# set explicitly (the argument exists from torch 1.13 onward).
# SECURITY: unpickling can execute arbitrary code -- only ever load a trusted,
# locally shipped checkpoint here, never a user-supplied file.
model_path = 'final_modelnew.pth'
model = torch.load(model_path, map_location=torch.device('cpu'), weights_only=False)
model.eval()  # Inference mode for layers such as dropout / batch-norm

# Database mapping user ID -> embedding array produced by generate_embedding.
# It lives only in process memory; nothing in this file persists it, so all
# registrations are lost on restart.
user_embeddings = {}
# Preprocess the image
# Built once at import time rather than on every call: the pipeline is
# stateless and identical for every image.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def preprocess_image(image):
    """Convert a raw image array into a single-image batch tensor for the model.

    Args:
        image: Image as a NumPy array (as delivered by ``gr.Image()``).
            Grayscale (H, W), RGB (H, W, 3) and RGBA (H, W, 4) arrays are
            accepted; all of them are converted to 3-channel RGB.

    Returns:
        A float tensor of shape (1, 3, 224, 224) with values in [0, 1].

    Raises:
        ValueError: If ``image`` is None, e.g. when a button is pressed
            without an uploaded picture.
    """
    if image is None:
        raise ValueError("no image provided.")
    # Let Pillow infer the mode from the array shape, then normalise to RGB.
    # Forcing mode 'RGB' (as before) fails on grayscale arrays and silently
    # misreads the pixel buffer of RGBA arrays.
    pil_image = Image.fromarray(np.asarray(image).astype('uint8')).convert('RGB')
    return _PREPROCESS(pil_image).unsqueeze(0)
# Generate embedding
def generate_embedding(image):
    """Run the embedding model on one image and return its output vector.

    Args:
        image: Raw image array, in whatever form ``preprocess_image`` accepts.

    Returns:
        The model's output for this image as a NumPy array, with the leading
        batch dimension (of size 1) removed.
    """
    batch = preprocess_image(image)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        batched_output = model(batch).numpy()
    return batched_output[0]
# Register new user
def register_user(image, user_id):
    """Compute an embedding for ``image`` and store it under ``user_id``.

    Registering an ID that already exists replaces its stored embedding.

    Args:
        image: Face image array from the image widget, or None if nothing
            was uploaded.
        user_id: Identifier entered by the user. Surrounding whitespace is
            stripped before it is used as the database key.

    Returns:
        A status string for display. Failures are reported as
        ``"Error during registration: ..."`` instead of being raised, so the
        UI always receives text.
    """
    # Validate up front: a blank textbox gives an empty string and an empty
    # image widget gives None. Without these checks a blank ID would be
    # registered silently, and a missing image would surface as an opaque
    # AttributeError message.
    user_id = "" if user_id is None else str(user_id).strip()
    if not user_id:
        return "Error during registration: user ID must not be empty."
    if image is None:
        return "Error during registration: no image provided."
    try:
        embedding = generate_embedding(image)
    except Exception as e:  # UI boundary: report model/preprocessing failures as text
        return f"Error during registration: {str(e)}"
    user_embeddings[user_id] = embedding
    return f"User {user_id} registered successfully."
# Recognize user
def recognize_user(image):
    """Find the registered user whose embedding is closest to that of ``image``.

    Closeness is measured with scipy's cosine distance (1 - cosine
    similarity). The closest user is accepted only if the distance is at most
    RECOGNITION_THRESHOLD.

    Args:
        image: Face image array from the image widget, or None if nothing
            was uploaded.

    Returns:
        ``"Recognized User: <id>"`` on a match. ``"User not recognized."`` when
        no registered user is close enough, which includes the case where
        nobody is registered. ``"Error during recognition: ..."`` on failure.
    """
    if image is None:
        return "Error during recognition: no image provided."
    try:
        new_embedding = generate_embedding(image)
        min_distance = float('inf')
        recognized_user_id = "Unknown"
        for user_id, embedding in user_embeddings.items():
            distance = cosine(new_embedding, embedding)
            # A NaN distance (scipy may return one for an all-zero vector)
            # compares False here, so such an entry is never picked.
            if distance < min_distance:
                min_distance = distance
                recognized_user_id = user_id
        # With an empty database min_distance stays inf, so this rejects.
        if min_distance > RECOGNITION_THRESHOLD:
            return "User not recognized."
        return f"Recognized User: {recognized_user_id}"
    except Exception as e:  # UI boundary: report failures as text instead of crashing
        return f"Error during recognition: {str(e)}"
def main():
    """Build the two-tab Gradio interface and start serving it.

    The "Register" tab wires an image plus a user-ID textbox to
    ``register_user``. The "Recognize" tab wires an image to
    ``recognize_user``. The app is launched with ``share=True``.
    """
    with gr.Blocks() as demo:
        gr.Markdown("Facial Recognition System")

        # Enrolment: store a face embedding under a user ID.
        with gr.Tab("Register"):
            with gr.Row():
                enrol_image = gr.Image()
                enrol_id_box = gr.Textbox(label="User ID")
            enrol_btn = gr.Button("Register")
            enrol_status = gr.Textbox()
            enrol_btn.click(
                register_user,
                inputs=[enrol_image, enrol_id_box],
                outputs=enrol_status,
            )

        # Identification: match a face against the enrolled embeddings.
        with gr.Tab("Recognize"):
            with gr.Row():
                probe_image = gr.Image()
            identify_btn = gr.Button("Recognize")
            identify_status = gr.Textbox()
            identify_btn.click(
                recognize_user,
                inputs=[probe_image],
                outputs=identify_status,
            )

    demo.launch(share=True)


if __name__ == "__main__":
    main()