SavlonBhai's picture
Update app.py
278cdaa verified
"""
Animal Type Classification App
A robust Gradio application for classifying animals using YOLOv8
"""
import gradio as gr
from ultralytics import YOLO
from PIL import Image
import numpy as np
import logging
import sys
import os
from typing import Optional
# Logging configuration: INFO-and-above records are written to stdout with a
# timestamp and level prefix.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# Configuration
# Relative path, so it is resolved against the process working directory.
MODEL_PATH = "best_animal_classifier.pt"
# Labels the app is described as supporting. Not referenced elsewhere in the
# visible code: predictions take their label from the model's own `names`.
CLASS_NAMES = ["butterfly", "chicken", "elephant", "horse", "spider", "squirrel"]

# Load the model once at import time. Every failure path leaves `model` as
# None so the app can still start and report the problem per request.
model = None
try:
    if os.path.exists(MODEL_PATH):
        model = YOLO(MODEL_PATH)
        logger.info("✅ Model loaded successfully!")
    else:
        logger.error("❌ Model file not found at %s", MODEL_PATH)
except Exception as e:
    # logger.exception records the traceback as well as the message, which a
    # plain logger.error would drop.
    logger.exception("❌ Error loading model: %s", e)
    model = None
def classify_animal(image: Optional["Image.Image"]) -> str:
    """Classify an uploaded image and return a one-line, human-readable result.

    Args:
        image: The image handed over by the Gradio input component (declared
            with ``type="pil"``), or ``None`` when nothing was uploaded.

    Returns:
        ``"Prediction: <LABEL> (<confidence as a percentage>)"`` when the
        model yields class probabilities; otherwise a short message saying
        why no prediction is available. Failures during inference are
        reported in the returned string rather than raised.
    """
    if image is None:
        return "Please upload an image."
    if model is None:
        return "Model not loaded. Check server logs."

    try:
        # The model call returns a list of results; a single image was passed,
        # so only the first entry is relevant.
        result = model(image)[0]
        probs = result.probs
        if probs is None:
            # NOTE(review): presumably this means the loaded weights are not a
            # classification model -- confirm against the Ultralytics docs.
            return "No animals detected or classification failed."

        # Index and confidence of the highest-probability class.
        top1_idx = probs.top1
        conf = probs.top1conf.item()
        label = result.names[top1_idx]
        return f"Prediction: {label.upper()} ({conf:.2%})"
    except Exception as e:
        # Keep the traceback in the server logs; the caller only gets the
        # short message below.
        logger.exception("Inference error: %s", e)
        return f"Error during classification: {e}"
# Gradio interface: one image input, one text output, backed by
# classify_animal.
demo = gr.Interface(
    fn=classify_animal,
    inputs=gr.Image(type="pil", label="Upload Animal Image"),
    outputs=gr.Textbox(label="Result"),
    title="🐾 Animal Type Classifier",
    description="Upload a photo of a butterfly, chicken, elephant, horse, spider, or squirrel to identify it.",
    # Offer the example only when the file is actually present next to the app.
    examples=[["example_elephant.jpg"]] if os.path.exists("example_elephant.jpg") else None,
    cache_examples=False,
)

if __name__ == "__main__":
    # Listen on all interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)