arsath-sm's picture
Update app.py
9552979 verified
raw
history blame
No virus
3.08 kB
import gradio as gr
import tensorflow as tf
import numpy as np
from huggingface_hub import hf_hub_download, list_repo_files
def list_files_in_repo(repo_id):
    """Print and return the file names of a Hugging Face Hub repository.

    The listing is echoed to stdout, one name per line, under a header line.
    If anything goes wrong while listing or printing, the error is reported
    on stdout and an empty list is returned instead of raising.
    """
    try:
        repo_files = list_repo_files(repo_id)
        print(f"Files in {repo_id}:")
        # Skip the body entirely for an empty repo so no blank line is printed.
        if repo_files:
            print("\n".join(repo_files))
    except Exception as e:
        print(f"Error listing files in {repo_id}: {str(e)}")
        repo_files = []
    return repo_files
def load_model_from_hub(repo_id):
    """Download and load the first Keras model file found in a Hub repository.

    The repository listing is scanned in order and the first file ending in
    ``.h5`` or ``.keras`` is downloaded and passed to
    ``tf.keras.models.load_model``.

    Raises:
        ValueError: if the listing contains no ``.h5`` or ``.keras`` file.
        Exception: any error from the download or the model load is printed
            and then re-raised unchanged.
    """
    candidates = [
        name
        for name in list_files_in_repo(repo_id)
        if name.endswith(('.h5', '.keras'))
    ]
    if not candidates:
        raise ValueError(f"No .h5 or .keras file found in {repo_id}")

    model_file = candidates[0]
    try:
        model_path = hf_hub_download(repo_id=repo_id, filename=model_file)
        return tf.keras.models.load_model(model_path)
    except Exception as e:
        print(f"Error loading model from {repo_id}: {str(e)}")
        raise
# Try to load models.
# Each model starts out as None and is only replaced when loading succeeds,
# so a failure in one repository never prevents the other from being tried.
model1 = None
try:
    print("Attempting to load Model 1...")
    model1 = load_model_from_hub("arsath-sm/face_classification_model1")
    print("Model 1 loaded successfully.")
except Exception as e:
    print(f"Failed to load Model 1: {str(e)}")

model2 = None
try:
    print("Attempting to load Model 2...")
    model2 = load_model_from_hub("arsath-sm/face_classification_model2")
    print("Model 2 loaded successfully.")
except Exception as e:
    print(f"Failed to load Model 2: {str(e)}")
def preprocess_image(image):
    """Turn an input image into a single-item batch for the models.

    The image is converted to a tensor, resized to 150x150, divided by 255.0
    and given a new leading axis of size one.
    """
    resized = tf.image.resize(tf.convert_to_tensor(image), (150, 150))
    return tf.expand_dims(resized / 255.0, 0)
def _describe_prediction(model, batch):
    """Return one model's verdict on ``batch`` as display text.

    A model that is None yields the fixed "Model failed to load" message.
    Otherwise the value at ``[0][0]`` of ``model.predict(batch)`` is read as a
    score: above 0.5 means "Real", anything else "Fake". The reported
    confidence is the score itself for "Real" and ``1 - score`` for "Fake".
    """
    if model is None:
        return "Model failed to load"
    score = model.predict(batch)[0][0]
    is_real = score > 0.5
    label = "Real" if is_real else "Fake"
    confidence = score if is_real else 1 - score
    return f"{label} (Confidence: {confidence:.2f})"


def predict_image(image):
    """Classify an image with every model that loaded successfully.

    Args:
        image: the image to classify, or None when no image was supplied.

    Returns:
        A dict of display strings. On success it has the keys
        "Model 1 Prediction" and "Model 2 Prediction"; a model that failed to
        load contributes "Model failed to load" under its key. When nothing
        can be predicted the dict has a single "Error" key instead.
    """
    if model1 is None and model2 is None:
        return {
            "Error": "Both models failed to load. Please check the model repositories and try again."
        }
    # The UI can submit without an image, in which case the input is None;
    # report that instead of letting the tensor conversion fail.
    if image is None:
        return {
            "Error": "No image provided. Please upload an image and try again."
        }

    # Preprocess once and share the batch between both models.
    preprocessed_image = preprocess_image(image)
    return {
        "Model 1 Prediction": _describe_prediction(model1, preprocessed_image),
        "Model 2 Prediction": _describe_prediction(model2, preprocessed_image),
    }
# Create the Gradio interface
# gr.Interface takes a component or a list of components as `outputs`, and the
# wrapped function must return one value per output component. predict_image
# returns a dict keyed by these labels, so it is adapted below.
_OUTPUT_LABELS = ("Model 1 Prediction", "Model 2 Prediction", "Error")


def _predict_for_interface(image):
    """Map predict_image's dict result to one string per output textbox.

    Labels missing from the result become empty strings, so the prediction
    boxes are blank when an error is reported and vice versa.
    """
    results = predict_image(image)
    return tuple(results.get(label, "") for label in _OUTPUT_LABELS)


iface = gr.Interface(
    fn=_predict_for_interface,
    inputs=gr.Image(),
    outputs=[gr.Textbox(label=label) for label in _OUTPUT_LABELS],
    title="Real vs AI Face Classification",
    description="Upload an image to classify whether it's a real face or an AI-generated face using two different models.",
)

# Launch the app
iface.launch()