halictus's picture
Update app.py
d448e99 verified
raw
history blame
1.43 kB
import gradio as gr
import requests
from PIL import Image
from io import BytesIO
from transformers import pipeline
import torch
# Build the image-classification pipeline once, at import time, so every
# request reuses the same loaded model instead of reloading it per call.
# Device 0 (first CUDA GPU) is used when CUDA is available, otherwise -1,
# which the transformers pipeline treats as CPU.
model_id = "Honey-Bee-Society/honeybee_bumblebee_vespidae_resnet50"
classifier = pipeline(
    "image-classification",
    model=model_id,
    device=0 if torch.cuda.is_available() else -1,
)
def classify_image_from_url(image_url: str, *, timeout: float = 10.0):
    """
    Download an image from a public URL and classify it with the module-level
    ``classifier`` pipeline (a fine-tuned ResNet-50 image classifier).

    Args:
        image_url: URL of the image to fetch. Leading/trailing whitespace
            (common when pasting into the textbox) is ignored.
        timeout: Seconds ``requests`` waits when connecting and between bytes
            received (not a cap on total download time). Without a timeout
            the request can block indefinitely on an unresponsive host.

    Returns:
        On success, the list of prediction dicts returned by the pipeline,
        with each ``"score"`` rounded to 8 decimal places. On failure, a dict
        with a single ``"error"`` key describing what went wrong.
    """
    try:
        # NOTE(review): the server fetches whatever URL the user supplies;
        # consider restricting schemes/hosts if this app is publicly exposed.
        response = requests.get((image_url or "").strip(), timeout=timeout)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")

        # Run inference.
        results = classifier(image)

        # Round scores to 8 decimal places. They remain floats, so very small
        # values can still serialize in exponent form (e.g. 1e-08).
        for r in results:
            r["score"] = round(float(r["score"]), 8)
        return results
    except requests.exceptions.RequestException as e:
        return {"error": f"Failed to download image: {e}"}
    except Exception as e:
        # Broad catch at the UI boundary: image-decoding or inference failures
        # are reported back to the user instead of crashing the request.
        return {"error": f"An error occurred during classification: {e}"}
# Web UI: one single-line URL textbox in, the classifier's result rendered
# as JSON out.
demo = gr.Interface(
    fn=classify_image_from_url,
    inputs=gr.Textbox(label="Image URL", lines=1),
    outputs="json",
    title="ResNet-50 Image Classifier",
    description="Enter public image URL to get top predictions.",
)

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()