Add all source files
- ViT_Caltech101_five_epochs.pth +3 -0
- app.py +74 -0
- class_names.txt +101 -0
- examples/image_0012.jpg +0 -0
- examples/image_0014.jpg +0 -0
- examples/image_0036.jpg +0 -0
- examples/image_0171.jpg +0 -0
- examples/image_0225.jpg +0 -0
- model.py +36 -0
- requirements.txt +4 -0
ViT_Caltech101_five_epochs.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0fe3cd9c70bf0532eae9315af3c6db81eea1c05a693bd400fee200a618774497
size 343568662
app.py
ADDED
@@ -0,0 +1,74 @@
import gradio as gr
import torchvision
import torch
from model import create_vit_instance
from pathlib import Path
from PIL import Image
from typing import List, Dict, Tuple
from timeit import default_timer as timer


# Read all available class names
with open('class_names.txt', 'r') as f:
    all_classes = [name.strip() for name in f.readlines()]

demo_vit_model, demo_vit_transforms = create_vit_instance(num_classes=len(all_classes),
                                                           device='cpu')

weights_path = Path("ViT_Caltech101_five_epochs.pth")
demo_vit_model.load_state_dict(torch.load(f=weights_path,
                                          map_location='cpu'))


# Prediction function: returns the class-probability dictionary and the time taken for the prediction
def predict(img: Image.Image,
            model: torch.nn.Module = demo_vit_model,
            transform: torchvision.transforms.Compose = demo_vit_transforms,
            classes: List[str] = all_classes) -> Tuple[Dict, float]:

    pred_prob_dict = dict()
    model = model.to('cpu')
    # Gradio passes a PIL image directly (inputs=gr.Image(type="pil")), so no Image.open is needed
    transformed_image = transform(img)

    start = timer()
    model.eval()
    with torch.inference_mode():
        batch_img = transformed_image.unsqueeze(dim=0).to(device='cpu')
        logit = model(batch_img)
        pred_probs = torch.softmax(input=logit, dim=1)
    end = timer()

    total_time = round(end - start, 4)
    pred_probs = pred_probs[0].tolist()

    # Map each class name to its predicted probability (gr.Label sorts and shows the top classes)
    for idx in range(len(pred_probs)):
        class_name = classes[idx]
        pred_prob_dict[class_name] = pred_probs[idx]

    return (pred_prob_dict, total_time)


title = "ObjectVision"
description = "ViT feature extractor fine-tuned for image classification on the Caltech101 dataset."
samples = [[path] for path in Path("examples").iterdir()]
demo = gr.Interface(fn=predict,
                    title=title,
                    description=description,
                    inputs=gr.Image(type="pil"),
                    examples=samples,
                    outputs=[
                        gr.Label(num_top_classes=5,
                                 label="Model thinks"),
                        gr.Number(label="Prediction time (in seconds)")
                    ])

if __name__ == "__main__":
    demo.launch(debug=True)
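A quick way to sanity-check predict() outside the Gradio UI is to call it directly on one of the example images shipped in this commit. This is a minimal local sketch, not part of the Space itself; it assumes the files above are present in the working directory.

from PIL import Image
from app import predict   # loads the model and weights at import time

img = Image.open("examples/image_0012.jpg")
probs, seconds = predict(img)
top_class = max(probs, key=probs.get)
print(f"Top prediction: {top_class} ({probs[top_class]:.3f}) in {seconds}s")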
class_names.txt
ADDED
@@ -0,0 +1,101 @@
Faces
Faces_easy
Leopards
Motorbikes
accordion
airplanes
anchor
ant
barrel
bass
beaver
binocular
bonsai
brain
brontosaurus
buddha
butterfly
camera
cannon
car_side
ceiling_fan
cellphone
chair
chandelier
cougar_body
cougar_face
crab
crayfish
crocodile
crocodile_head
cup
dalmatian
dollar_bill
dolphin
dragonfly
electric_guitar
elephant
emu
euphonium
ewer
ferry
flamingo
flamingo_head
garfield
gerenuk
gramophone
grand_piano
hawksbill
headphone
hedgehog
helicopter
ibis
inline_skate
joshua_tree
kangaroo
ketch
lamp
laptop
llama
lobster
lotus
mandolin
mayfly
menorah
metronome
minaret
nautilus
octopus
okapi
pagoda
panda
pigeon
pizza
platypus
pyramid
revolver
rhino
rooster
saxophone
schooner
scissors
scorpion
sea_horse
snoopy
soccer_ball
stapler
starfish
stegosaurus
stop_sign
strawberry
sunflower
tick
trilobite
umbrella
watch
water_lilly
wheelchair
wild_cat
windsor_chair
wrench
yin_yang
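Because predict() maps softmax index idx straight to classes[idx], the order of these names has to match the label order used when the classification head was trained. A minimal consistency check, assuming the file layout above and the 101-class head from this commit:

with open("class_names.txt") as f:
    names = [line.strip() for line in f if line.strip()]
assert len(names) == 101, f"expected 101 classes, got {len(names)}"
assert len(names) == len(set(names)), "class names should be unique"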
examples/image_0012.jpg
ADDED
examples/image_0014.jpg
ADDED
examples/image_0036.jpg
ADDED
examples/image_0171.jpg
ADDED
examples/image_0225.jpg
ADDED
model.py
ADDED
@@ -0,0 +1,36 @@
import torch
from torch import nn
from collections import OrderedDict
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.models import vit_b_16, ViT_B_16_Weights


def create_vit_instance(num_classes: int = 1000,
                        device: torch.device = 'cpu'):
    # Pretrained ViT-B/16 backbone
    vit_weight = ViT_B_16_Weights.DEFAULT
    vit_model = vit_b_16(weights=vit_weight).to(device)

    # Freeze the backbone so only the new classification head is trainable
    for param in vit_model.parameters():
        param.requires_grad = False

    # Replace the classification head to output the requested number of classes
    vit_model.heads = nn.Sequential(
        OrderedDict([
            ('head', nn.Linear(in_features=768,
                               out_features=num_classes))
        ])
    ).to(device)

    # Custom transform (used instead of vit_weight.transforms()) so that
    # single-channel Caltech101 images are expanded to the 3 channels the model expects
    transform = transforms.Compose([
        transforms.Resize(256, interpolation=InterpolationMode.BILINEAR),
        transforms.CenterCrop(224),
        transforms.Grayscale(num_output_channels=3),  # always output 3 channels
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    return (vit_model, transform)
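A minimal sketch of how the returned model and transform fit together; the dummy single-channel image below is only an assumption for illustration. The Grayscale(num_output_channels=3) step means even a 1-channel input ends up as a (1, 3, 224, 224) batch, and the replaced head produces one logit per class.

import torch
from PIL import Image
from model import create_vit_instance

model, transform = create_vit_instance(num_classes=101, device='cpu')
dummy = Image.new("L", (300, 200))        # single-channel test image
batch = transform(dummy).unsqueeze(0)     # shape: (1, 3, 224, 224)
with torch.inference_mode():
    logits = model(batch)                 # shape: (1, 101)
print(logits.shape)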
requirements.txt
ADDED
@@ -0,0 +1,4 @@
gradio==5.0.2
torch==2.4.0
torchvision==0.19.0