|
import os

import gradio as gr
import torch
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
# Load the ViT feature extractor and the pretrained image-classification model.
# (ViTFeatureExtractor is the older name; recent transformers releases expose
# the same preprocessing as ViTImageProcessor.)
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
|
|
def inference(image: Image.Image):
    # Preprocess the PIL image into model-ready pixel tensors.
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Forward pass; no gradients are needed for inference.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits

    # Map the highest-scoring class index to its human-readable label.
    predicted_class_idx = logits.argmax(-1).item()
    return "Predicted class: " + model.config.id2label[predicted_class_idx]
|
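# Optional variant, not wired into the UI below: gr.Label can also render a
# dict of label -> confidence scores. This is a minimal sketch assuming a
# softmax over the logits; the helper name `inference_top5` is illustrative
# only and not part of the original template.
def inference_top5(image: Image.Image):
    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # Convert logits to probabilities and keep the five most likely classes.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    top = torch.topk(probs, k=5)
    return {model.config.id2label[i.item()]: p.item()
            for i, p in zip(top.indices, top.values)}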
|
|
# Build the Gradio UI.
with gr.Blocks() as demo:
|
    gr.Markdown(
        """
# Welcome to this Replit Template for Gradio!

Start by adding an image. This demo uses the google/vit-base-patch16-224 model from the Hugging Face model Hub for image classification; for more details, read the [model card on Hugging Face](https://huggingface.co/google/vit-base-patch16-224).
""")
|
    # Input image (delivered to the callback as a PIL image) and output label.
    inp = gr.Image(type="pil")
    out = gr.Label()
    button = gr.Button(value="Run")
|
    # Bundled example image; clicking it runs inference on it (not cached).
    gr.Examples(
        examples=[os.path.join(os.path.dirname(__file__), "lion.jpeg")],
        inputs=inp,
        outputs=out,
        fn=inference,
        cache_examples=False)
|
    # Wire the Run button to the inference function.
    button.click(fn=inference, inputs=inp, outputs=out)
|
# share=True serves the app locally and also exposes a temporary public URL.
demo.launch(share=True)