Spaces:
Runtime error
Runtime error
import functools
import io
import logging
import os

import gradio as gr
import numpy as np
import requests
from PIL import Image
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face API settings
HF_API_URL = "https://api-inference.huggingface.co/models/jeemsterri/fish_classification"
# Read the access token from the environment (e.g. a Space secret) instead of
# hard-coding it in source. HF_API_KEY takes precedence over the conventional
# HF_TOKEN; an empty string means "no token configured".
HF_API_KEY = os.environ.get("HF_API_KEY") or os.environ.get("HF_TOKEN", "")
_MOBILENET_URL = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4"
_MOBILENET_INPUT_SIZE = (224, 224)  # MobileNet expects 224x224


def _encode_image(image: Image.Image, fmt: str = "PNG") -> bytes:
    """Return *image* serialized as an encoded image file (PNG by default).

    The Inference API needs real image-file bytes; ``Image.tobytes()`` only
    yields the raw, header-less pixel buffer, which the server cannot decode.
    """
    buffer = io.BytesIO()
    image.save(buffer, format=fmt)
    return buffer.getvalue()


def _query_hf_api(img_bytes: bytes, content_type: str = "image/png", timeout: float = 30.0):
    """POST encoded image bytes to the Hugging Face Inference API.

    Args:
        img_bytes: Encoded image file contents.
        content_type: MIME type of ``img_bytes``.
        timeout: Seconds to wait before giving up on the request.

    Returns:
        The decoded JSON payload on HTTP 200, or ``None`` when the request
        fails at the network level or returns a non-200 status, so the
        caller can fall back to the local model.
    """
    headers = {"Content-Type": content_type}
    if HF_API_KEY:
        headers["Authorization"] = f"Bearer {HF_API_KEY}"
    try:
        response = requests.post(HF_API_URL, headers=headers, data=img_bytes, timeout=timeout)
    except requests.RequestException as exc:
        logger.warning("API request failed (%s), using fallback...", exc)
        return None
    if response.status_code != 200:
        logger.warning("API failed (status %s), using fallback...", response.status_code)
        return None
    predictions = response.json()
    logger.info("API response: %s", predictions)
    return predictions


@functools.lru_cache(maxsize=1)
def _load_mobilenet():
    """Load the TF-Hub MobileNet V2 classifier once and reuse it across requests."""
    # Imported lazily so TensorFlow is only needed when the API is unavailable.
    import tensorflow_hub as hub

    return hub.KerasLayer(_MOBILENET_URL)


def _classify_with_mobilenet(image: Image.Image) -> dict:
    """Classify an RGB *image* with the local MobileNet V2 fallback model."""
    import tensorflow as tf

    classifier = _load_mobilenet()
    resized = image.resize(_MOBILENET_INPUT_SIZE)
    batch = np.asarray(resized, dtype=np.float32)[np.newaxis, ...] / 255.0
    logits = classifier(tf.constant(batch))
    # The TF-Hub model emits 1001 logits with index 0 = "background", while
    # decode_predictions expects exactly the 1000 ImageNet classes: drop
    # column 0 and softmax so the reported score is a probability.
    probabilities = tf.nn.softmax(logits[:, 1:], axis=-1).numpy()
    _, label, score = tf.keras.applications.mobilenet_v2.decode_predictions(
        probabilities, top=1
    )[0][0]
    return {
        "source": "MobileNet (Fallback)",
        "predictions": [{"label": label, "score": float(score)}],
    }


def classify_fish(image: Image.Image) -> dict:
    """
    Classify a fish image using Hugging Face API or fallback to MobileNet.

    Args:
        image: PIL Image object (``None`` when the user submits no image).

    Returns:
        Dict with ``source`` and ``predictions`` keys on success, or a
        single ``error`` key describing what went wrong.
    """
    if image is None:
        return {"error": "No image provided."}
    try:
        # Normalise to 3-channel RGB: RGBA/greyscale/palette inputs would
        # otherwise break the (224, 224, 3) MobileNet input.
        rgb_image = image.convert("RGB")
        predictions = _query_hf_api(_encode_image(rgb_image))
        if predictions is not None:
            return {"source": "Hugging Face", "predictions": predictions}
        return _classify_with_mobilenet(rgb_image)
    except Exception as e:
        # Top-level UI handler: surface the failure as JSON instead of crashing.
        logger.exception("Classification error: %s", e)
        return {"error": str(e)}
# Gradio Interface
# Only offer example images that actually exist next to the app: a missing
# example path makes Gradio raise while building the interface.
_EXAMPLE_IMAGES = [path for path in ("salmon.jpg", "tuna.jpg") if os.path.isfile(path)]

interface = gr.Interface(
    fn=classify_fish,
    inputs=gr.Image(type="pil", label="Upload Fish Image"),
    outputs=gr.JSON(label="Prediction Results"),
    title="🐟 Fish Classifier",
    description="Upload an image of a fish to see the predicted class probabilities.",
    examples=_EXAMPLE_IMAGES or None,  # None: no examples panel when none are present
    theme="soft",
)
if __name__ == "__main__":
    # Bind to every network interface on port 7860 so the app is reachable
    # from outside the host/container it runs in.
    interface.launch(
        server_port=7860,
        server_name="0.0.0.0",
    )