desc / app.py
PhilHolst's picture
Create app.py
b9e51b5
raw
history blame
No virus
1.45 kB
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import requests
import gradio as gr
import os
# One Hub checkpoint id feeds both loaders so the preprocessor and the
# classifier always come from the same pretrained model.
_CHECKPOINT = 'google/vit-base-patch16-224'

# Preprocessor that converts an input image into the tensors the model is called with.
feature_extractor = ViTFeatureExtractor.from_pretrained(_CHECKPOINT)
# ViT image-classification model loaded from the same checkpoint.
model = ViTForImageClassification.from_pretrained(_CHECKPOINT)
def inference(image):
    """Classify an image with the module-level ViT model and return its label.

    Args:
        image: The picture to classify, as delivered by the
            ``gr.Image(type="pil")`` input, or ``None`` when nothing was
            supplied.

    Returns:
        The string ``"Predicted class:"`` followed by the label that
        ``model.config.id2label`` maps the highest-scoring logit to, or a
        short prompt asking for an image when ``image`` is ``None``.
    """
    if image is None:
        # An empty image input arrives as None; the feature extractor cannot
        # process that, so answer with a hint instead of raising.
        return "Please provide an image."

    # Imported locally because the file does not import torch at module level.
    import torch

    inputs = feature_extractor(images=image, return_tensors="pt")
    # Prediction only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits
    # model predicts one of the 1000 ImageNet classes
    predicted_class_idx = logits.argmax(-1).item()
    return "Predicted class:" + model.config.id2label[predicted_class_idx]
# Assemble the demo page: intro text, an image input, a label output, a "Run"
# button and one bundled example, all wired to `inference`.
with gr.Blocks() as demo:
    gr.Markdown(
"""
# Welcome to this Replit Template for Gradio!
Start by adding a image, this demo uses google/vit-base-patch16-224 model from Hugging Face model Hub for a image classification demo, for more details read the [model card on Hugging Face](https://huggingface.co/google/vit-base-patch16-224)
""")
    inp = gr.Image(type="pil")
    out = gr.Label()
    button = gr.Button(value="Run")

    # The sample picture is looked up next to this script.
    example_image = os.path.join(os.path.dirname(__file__), "lion.jpeg")
    gr.Examples(
        examples=[example_image],
        inputs=inp,
        outputs=out,
        fn=inference,
        cache_examples=False,
    )

    # Clicking "Run" classifies the current image and shows the result in the label.
    button.click(fn=inference, inputs=inp, outputs=out)

# Start the app with share=True, as the original script does.
demo.launch(share=True)