File size: 607 Bytes
d66d160
8090b75
 
 
4625865
8f420ad
 
4625865
8090b75
d66d160
 
 
 
 
 
 
 
 
 
 
3ce0ef7
d66d160
10f177d
 
d66d160
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from transformers import ViTConfig, ViTForImageClassification
from transformers import ViTFeatureExtractor
from PIL import Image
import requests
import matplotlib.pyplot as plt
import gradio as gr




# Option 1: build the model from a configuration object alone, so it starts
# with randomly initialized weights (i.e. it is meant to be trained from scratch).
config = ViTConfig(
    hidden_size=768,
    num_hidden_layers=12,
)
model = ViTForImageClassification(config)

# Print the configuration so the chosen settings can be inspected.
print(config)

# Feature extractor created with its default arguments.
# Alternative: load the one that corresponds to a checkpoint on the hub:
#   feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
feature_extractor = ViTFeatureExtractor()

# File name of the sample image; only the string is stored here, nothing is opened.
image = "cats.jpg"