ioanasong's picture
added webcam file for huggingface
5dd71fd
raw
history blame contribute delete
No virus
1.86 kB
import gradio as gr
import torch
from PIL import Image
from torch.nn import functional as F
from transformers import (
    ViTFeatureExtractor,
    ViTForImageClassification,
    ViTImageProcessor,
)
# gr.load("models/ioanasong/vit-MINC-2500").launch()
# Load the pre-trained ViT classifier and its matching preprocessor once, at
# module import, so every call to predict() reuses the same objects.
model_name = "ioanasong/vit-MINC-2500"
model = ViTForImageClassification.from_pretrained(model_name)
# The app only runs inference, so put the model in evaluation mode.
model.eval()
# ViTImageProcessor supersedes the deprecated ViTFeatureExtractor. The variable
# keeps its original name because predict() looks it up as `feature_extractor`.
feature_extractor = ViTImageProcessor.from_pretrained(model_name)
# Define the prediction function
# def predict(image):
# print(image)
# # Preprocess the image
# inputs = feature_extractor(images=image, return_tensors="pt")
# # Make prediction
# with torch.no_grad():
# outputs = model(**inputs)
# logits = outputs.logits
# # Get predicted label
# predicted_class_idx = logits.argmax(-1).item()
# predicted_label = model.config.id2label[predicted_class_idx]
# return predicted_label
def predict(image):
    """Classify one image and return a probability for every model label.

    Args:
        image: The image handed over by the Gradio image input, or ``None``
            when no image is available.

    Returns:
        A dict mapping each label in ``model.config.id2label`` to its softmax
        probability (a Python float). An empty dict is returned when
        ``image`` is ``None``.
    """
    # Gradio may invoke the callback with None (for example before the webcam
    # has produced a frame); return an empty result rather than letting the
    # preprocessor raise on it.
    if image is None:
        return {}

    # Preprocess the image into PyTorch tensors for the model.
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only, so skip gradient tracking.
    with torch.no_grad():
        logits = model(**inputs).logits

    # The batch holds a single image: take row 0 and convert the whole row to
    # plain floats in one call instead of calling .item() per element.
    probs = F.softmax(logits, dim=-1)[0].tolist()

    id2label = model.config.id2label
    return {id2label[i]: prob for i, prob in enumerate(probs)}
# Assemble the Gradio UI: a streaming webcam image goes in, and the five most
# probable labels come out.
webcam_input = gr.Image(sources=["webcam"], streaming=True)
# To list every class instead of the top five, use
# num_top_classes=len(model.config.id2label).
top_labels_output = gr.Label(num_top_classes=5)

iface = gr.Interface(
    fn=predict,
    inputs=webcam_input,
    outputs=top_labels_output,
    title="ViT Image Classification",
    description="Capture an image from the camera and classify it using a pre-trained Vision Transformer (ViT) model.",
)

# Start serving the app.
iface.launch()