# app.py — uploaded by riyadifirman (commit 0d599fc, 1.27 kB).
import gradio as gr
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from PIL import Image
# Hugging Face Hub checkpoint that supplies both the image processor
# (normalization statistics) and the image-classification model weights.
model_name = "riyadifirman/klasifikasiburung"
processor, model = (
    AutoImageProcessor.from_pretrained(model_name),
    AutoModelForImageClassification.from_pretrained(model_name),
)
# Preprocessing applied to every uploaded image before inference.
# Mean/std are taken from the checkpoint's image processor so pixel statistics
# match the model's expectations; the 224x224 target size is fixed here.
normalize = Normalize(
    std=processor.image_std,
    mean=processor.image_mean,
)
transform = Compose(
    [
        Resize((224, 224)),  # fixed square input resolution
        ToTensor(),          # PIL image (H, W, C) in [0, 255] -> float tensor (C, H, W) in [0, 1]
        normalize,           # per-channel (x - mean) / std
    ]
)
def predict(image):
    """Classify a bird image and return the predicted class label.

    Args:
        image: Image as a NumPy array, as delivered by the
            ``gr.Image(type="numpy")`` input; ``None`` when the user submits
            without uploading an image.

    Returns:
        The label of the highest-scoring class (from ``model.config.id2label``),
        or a prompt asking for an image when none was provided.
    """
    if image is None:
        return "Please upload an image."

    # Grayscale (1-channel) or RGBA (4-channel) inputs would not match the
    # 3-channel mean/std used by ``normalize``; force RGB.
    pil_image = Image.fromarray(image).convert("RGB")
    pixel_values = transform(pil_image).unsqueeze(0)  # add batch dim -> (1, C, H, W)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits
    predicted_class_idx = logits.argmax(-1).item()

    # Image processors have no ``decode`` method; the index-to-label
    # mapping lives in the model config.
    return model.config.id2label[predicted_class_idx]
# Web UI: a single image input (passed to ``predict`` as a NumPy array) and a
# plain-text output showing the predicted label. ``gr.Image`` is the current
# component; the legacy ``gr.inputs.Image`` namespace is no longer used.
interface = gr.Interface(
    title="Bird Classification",
    description="Upload an image of a bird to classify it.",
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs="text",
)

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    interface.launch()