fish-classifier / app.py
And00drew's picture
Update app.py
12d721d verified
import functools
import io
import logging
import os

import gradio as gr
import numpy as np
import requests
from PIL import Image
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face API settings
HF_API_URL = "https://api-inference.huggingface.co/models/jeemsterri/fish_classification"
# Read the token from the environment so a real key never has to be committed
# to source control. HF_API_KEY takes precedence, then HF_TOKEN; the
# placeholder is only a last resort so the module still imports without one.
HF_API_KEY = (
    os.environ.get("HF_API_KEY")
    or os.environ.get("HF_TOKEN")
    or "your_huggingface_api_key"  # Replace with your key, or set HF_API_KEY / HF_TOKEN
)
def _encode_image(image: Image.Image) -> bytes:
    """Serialize *image* to PNG-encoded file bytes suitable for an HTTP upload."""
    # Image.tobytes() would return raw decoded pixels with no header, which the
    # Inference API cannot parse; it needs an actual image file payload.
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()


def _query_hf_api(image_bytes: bytes, timeout: float = 30.0):
    """Send encoded image bytes to the Hugging Face Inference API.

    Args:
        image_bytes: Encoded (PNG) image file contents.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The decoded JSON body on HTTP 200, or ``None`` when the request fails
        (network error, non-200 status, or a non-JSON body) so the caller can
        fall back to the local model.
    """
    headers = {
        "Authorization": f"Bearer {HF_API_KEY}",
        "Content-Type": "image/png",
    }
    try:
        response = requests.post(
            HF_API_URL, headers=headers, data=image_bytes, timeout=timeout
        )
    except requests.RequestException as exc:
        logger.warning("API request failed (%s), using fallback...", exc)
        return None
    if response.status_code != 200:
        logger.warning(
            "API failed (status %s), using fallback...", response.status_code
        )
        return None
    try:
        return response.json()
    except ValueError as exc:  # requests' JSONDecodeError subclasses ValueError
        logger.warning("API returned non-JSON body (%s), using fallback...", exc)
        return None


@functools.lru_cache(maxsize=1)
def _load_mobilenet():
    """Build the TF Hub MobileNetV2 classifier once and reuse it across calls."""
    # Imported lazily so TensorFlow is only loaded when the fallback is needed.
    import tensorflow as tf
    import tensorflow_hub as hub

    return tf.keras.Sequential([
        hub.KerasLayer("https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4")
    ])


def _classify_with_mobilenet(image: Image.Image) -> list:
    """Return the top ImageNet prediction for an RGB *image* as ``[{label, score}]``."""
    import tensorflow as tf

    model = _load_mobilenet()
    resized = image.resize((224, 224))  # MobileNet expects 224x224
    batch = np.expand_dims(np.asarray(resized, dtype=np.float32) / 255.0, axis=0)
    logits = model.predict(batch)
    # The TF Hub classifier emits 1001 logits with index 0 = "background";
    # decode_predictions requires exactly the 1000 ImageNet classes, and the
    # raw outputs are logits, so convert to probabilities and drop column 0.
    probabilities = tf.nn.softmax(logits, axis=-1).numpy()[:, 1:]
    top_prediction = tf.keras.applications.mobilenet_v2.decode_predictions(
        probabilities, top=1
    )[0][0]
    return [{"label": top_prediction[1], "score": float(top_prediction[2])}]


def classify_fish(image: Image.Image) -> dict:
    """
    Classify a fish image using Hugging Face API or fallback to MobileNet.

    The Hugging Face Inference API is tried first; if the request errors,
    times out, returns a non-200 status or a non-JSON body, the local
    TF Hub MobileNetV2 ImageNet classifier is used instead.

    Args:
        image: PIL Image object, or ``None`` when nothing was uploaded.

    Returns:
        ``{"source": ..., "predictions": ...}`` on success, or
        ``{"error": message}`` if no image was given or classification failed.
    """
    if image is None:
        return {"error": "No image provided."}
    try:
        # Normalize RGBA / grayscale / palette uploads to 3-channel RGB.
        rgb_image = image.convert("RGB")
        predictions = _query_hf_api(_encode_image(rgb_image))
        if predictions is not None:
            logger.info("API response: %s", predictions)
            return {"source": "Hugging Face", "predictions": predictions}
        return {
            "source": "MobileNet (Fallback)",
            "predictions": _classify_with_mobilenet(rgb_image),
        }
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        logger.exception("Classification error: %s", e)
        return {"error": str(e)}
# Gradio Interface
# Only offer example images that actually exist (paths relative to the working
# directory); passing missing example files can make Gradio fail at startup.
_EXAMPLE_IMAGES = [
    path for path in ("salmon.jpg", "tuna.jpg") if os.path.isfile(path)
]

interface = gr.Interface(
    fn=classify_fish,
    inputs=gr.Image(type="pil", label="Upload Fish Image"),
    outputs=gr.JSON(label="Prediction Results"),
    title="🐟 Fish Classifier",
    description="Upload an image of a fish to see the predicted class probabilities.",
    examples=_EXAMPLE_IMAGES or None,  # None disables the examples panel
    theme="soft",
)

if __name__ == "__main__":
    # Bind on all interfaces so the app is reachable from outside a container.
    interface.launch(server_name="0.0.0.0", server_port=7860)