# Scraped page header (converted to comments so the file is valid Python):
#   author: autonomous019
#   last commit: 8f420ad "import gradio"
#   file size: 733 bytes
from transformers import ViTConfig, ViTForImageClassification
from transformers import ViTFeatureExtractor
from PIL import Image
import requests
import matplotlib.pyplot as plt
import gradio as gr
# Option 1: build a ViT image classifier with randomly initialized weights
# (i.e. to train from scratch): 12 transformer layers, hidden size 768.
config = ViTConfig(num_hidden_layers=12, hidden_size=768)
model = ViTForImageClassification(config)
print(config)

# Preprocessing: load the feature extractor that corresponds to the
# google/vit-base-patch16-224 checkpoint on the Hugging Face Hub.
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")

# Download a sample COCO val2017 image and keep a local copy.
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# Bound the request with a timeout so the script cannot hang forever.
# Raise on HTTP errors instead of letting PIL fail on an error page.
# Close the connection once the image bytes have been consumed; load()
# forces PIL to finish reading before the response is closed.
with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    image = Image.open(response.raw)
    image.load()
image.save("cats.png")