project1 / projekikan
Ikmalhakim's picture
Create projekikan
3071827 verified
raw
history blame contribute delete
No virus
974 Bytes
import gradio as gr
import requests
import torch
from PIL import Image
from transformers import (
    AutoFeatureExtractor,
    AutoImageProcessor,
    AutoModelForImageClassification,
)
# Load the pre-trained ViT classifier and its matching preprocessor once at
# import time; `predict` reuses these module-level objects on every call.
model_id = "google/vit-base-patch16-224"
model = AutoModelForImageClassification.from_pretrained(model_id)
# Feature extractors for vision models are deprecated in Transformers in
# favour of image processors; AutoImageProcessor loads the same
# preprocessing configuration from the checkpoint. The `feature_extractor`
# name is kept so existing references (e.g. in `predict`) keep working.
feature_extractor = AutoImageProcessor.from_pretrained(model_id)
def predict(image):
    """Classify one image and return its most likely label.

    Args:
        image: The uploaded picture as a ``PIL.Image.Image`` (the Gradio
            input is configured with ``type="pil"``), or ``None`` when the
            user submits without uploading anything.

    Returns:
        dict[str, float]: A single entry mapping the top-1 label name
        (looked up in ``model.config.id2label``) to its softmax
        probability, which is the mapping format Gradio's ``"label"``
        output renders.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable
            message instead of an internal preprocessing traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # NOTE(review): the processor is assumed to expect 3-channel RGB input;
    # RGBA (e.g. transparent PNG) or grayscale uploads are converted first
    # so they are not rejected or mis-normalized during preprocessing.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess the single image into a PyTorch batch of size one.
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax over the class axis, then take row 0 (the only image).
    probs = logits.softmax(dim=-1)[0]
    # One reduction yields both the top probability and its class index.
    probability, index = probs.max(dim=-1)
    predicted_label = model.config.id2label[index.item()]
    return {predicted_label: probability.item()}
# Define the Gradio interface: one PIL image in, a label/confidence view out.
# `gr.inputs.Image` was deprecated in Gradio 3.x and removed in 4.0; the
# top-level `gr.Image` component is the supported equivalent.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="label",
)
iface.launch()