itzRahul committed
Commit 186d3d2 · 1 Parent(s): 29ea379

Add all source files

ViT_Caltech101_five_epochs.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0fe3cd9c70bf0532eae9315af3c6db81eea1c05a693bd400fee200a618774497
+ size 343568662
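
The checkpoint itself lives in Git LFS, so the diff above only records the pointer file; the real ~343 MB weights are fetched when the repo is cloned with LFS enabled. As a minimal integrity-check sketch (using only the oid and size recorded in the pointer above), the downloaded file can be hashed and compared:

```python
import hashlib
from pathlib import Path

# Values recorded in the Git LFS pointer above
EXPECTED_OID = "0fe3cd9c70bf0532eae9315af3c6db81eea1c05a693bd400fee200a618774497"
EXPECTED_SIZE = 343568662

path = Path("ViT_Caltech101_five_epochs.pth")
# If this fails, the file is likely still the tiny LFS pointer, not the weights
assert path.stat().st_size == EXPECTED_SIZE, "unexpected size -- LFS pointer not resolved?"

# Hash in 1 MiB chunks to avoid loading all 343 MB into memory at once
sha = hashlib.sha256()
with path.open("rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha.update(chunk)
assert sha.hexdigest() == EXPECTED_OID, "checksum mismatch"
```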
app.py ADDED
@@ -0,0 +1,69 @@
+ import gradio as gr
+ import torchvision
+ from model import create_vit_instance
+ from pathlib import Path
+ import torch
+ from PIL import Image
+ from typing import List, Dict, Tuple
+ from timeit import default_timer as timer
+
+
+ # Read all available classes
+ with open('class_names.txt', 'r') as f:
+     all_classes = [name.strip() for name in f.readlines()]
+
+ demo_vit_model, demo_vit_transforms = create_vit_instance(num_classes=len(all_classes),
+                                                           device='cpu')
+
+ weights_path = Path("ViT_Caltech101_five_epochs.pth")
+ demo_vit_model.load_state_dict(torch.load(f=weights_path,
+                                           map_location='cpu'))
+
+
+ # Predict: returns the class-probability dictionary and the time taken for the prediction
+ def predict(img: Image.Image,
+             model: torch.nn.Module = demo_vit_model,
+             transform: torchvision.transforms.Compose = demo_vit_transforms,
+             classes: List[str] = all_classes) -> Tuple[Dict, float]:
+
+     pred_prob_dict = dict()
+     model = model.to('cpu')
+     transformed_image = transform(img)
+
+     start = timer()
+     model.eval()
+     with torch.inference_mode():
+         batch_img = transformed_image.unsqueeze(dim=0).to(device='cpu')
+         logit = model(batch_img)
+         pred_probs = torch.softmax(input=logit,
+                                    dim=1)
+     end = timer()
+
+     total_time = round(end - start, 4)
+     pred_probs = pred_probs[0].tolist()
+
+     for idx in range(len(pred_probs)):
+         class_name = classes[idx]
+         pred_prob_dict[class_name] = pred_probs[idx]
+
+     return (pred_prob_dict, total_time)
+
+
+
+ title = "ObjectVision"
+ description = "ViT feature extractor fine-tuned for image classification on the Caltech101 dataset."
+ samples = [[path] for path in Path("examples").iterdir()]
+ demo = gr.Interface(fn=predict,
+                     title=title,
+                     description=description,
+                     inputs=gr.Image(type="pil"),
+                     examples=samples,
+                     outputs=[
+                         gr.Label(num_top_classes=5,
+                                  label="Model thinks"),
+                         gr.Number(label="Prediction time (in seconds)")
+                     ])
+
+ if __name__ == "__main__":
+     demo.launch(debug=True)
+
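
Because gr.Image(type="pil") hands predict() a PIL image, the function can also be exercised outside the Gradio UI. A small sketch, assuming it is run from the repo root so the bundled example images are available:

```python
from PIL import Image

from app import predict  # predict() and its model/transform defaults come from app.py above

# Drive the classifier directly with one of the bundled example images
img = Image.open("examples/image_0012.jpg")
probs, seconds = predict(img)

# Mirror gr.Label(num_top_classes=5): show the five most confident classes
for name, p in sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:5]:
    print(f"{name:>15}  {p:.4f}")
print(f"prediction took {seconds} s")
```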
class_names.txt ADDED
@@ -0,0 +1,101 @@
+ Faces
+ Faces_easy
+ Leopards
+ Motorbikes
+ accordion
+ airplanes
+ anchor
+ ant
+ barrel
+ bass
+ beaver
+ binocular
+ bonsai
+ brain
+ brontosaurus
+ buddha
+ butterfly
+ camera
+ cannon
+ car_side
+ ceiling_fan
+ cellphone
+ chair
+ chandelier
+ cougar_body
+ cougar_face
+ crab
+ crayfish
+ crocodile
+ crocodile_head
+ cup
+ dalmatian
+ dollar_bill
+ dolphin
+ dragonfly
+ electric_guitar
+ elephant
+ emu
+ euphonium
+ ewer
+ ferry
+ flamingo
+ flamingo_head
+ garfield
+ gerenuk
+ gramophone
+ grand_piano
+ hawksbill
+ headphone
+ hedgehog
+ helicopter
+ ibis
+ inline_skate
+ joshua_tree
+ kangaroo
+ ketch
+ lamp
+ laptop
+ llama
+ lobster
+ lotus
+ mandolin
+ mayfly
+ menorah
+ metronome
+ minaret
+ nautilus
+ octopus
+ okapi
+ pagoda
+ panda
+ pigeon
+ pizza
+ platypus
+ pyramid
+ revolver
+ rhino
+ rooster
+ saxophone
+ schooner
+ scissors
+ scorpion
+ sea_horse
+ snoopy
+ soccer_ball
+ stapler
+ starfish
+ stegosaurus
+ stop_sign
+ strawberry
+ sunflower
+ tick
+ trilobite
+ umbrella
+ watch
+ water_lilly
+ wheelchair
+ wild_cat
+ windsor_chair
+ wrench
+ yin_yang
examples/image_0012.jpg ADDED
examples/image_0014.jpg ADDED
examples/image_0036.jpg ADDED
examples/image_0171.jpg ADDED
examples/image_0225.jpg ADDED
model.py ADDED
@@ -0,0 +1,38 @@
+
+ import torch
+ from torch import nn
+ from collections import OrderedDict
+ from torchvision import transforms
+ from torchvision.transforms import InterpolationMode
+ from torchvision.models import vit_b_16, ViT_B_16_Weights
+
+ def create_vit_instance(num_classes: int = 1000,
+                         device: torch.device = 'cpu'):
+     vit_weight = ViT_B_16_Weights.DEFAULT
+
+     vit_model = vit_b_16(weights=vit_weight).to(device)
+
+     # Freeze the pretrained backbone; only the new head will be trained
+     for param in vit_model.parameters():
+         param.requires_grad = False
+
+     # Replace the classification head with one sized for num_classes
+     vit_model.heads = nn.Sequential(
+         OrderedDict([
+             ('head', nn.Linear(in_features=768,
+                                out_features=num_classes))
+         ])
+     ).to(device)
+
+     # ViT_B_16 preprocessing, plus a 3-channel conversion so grayscale
+     # Caltech101 images match the RGB input the model expects
+     transform = transforms.Compose([
+         transforms.Resize(256, interpolation=InterpolationMode.BILINEAR),
+         transforms.CenterCrop(224),
+         transforms.Grayscale(num_output_channels=3),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+     ])
+
+     return (vit_model, transform)
+
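
A quick wiring check for create_vit_instance (a sketch; the first run downloads the pretrained ViT-B/16 weights): after freezing, only the replacement head should be trainable, and a 224x224 batch should yield one logit per class:

```python
import torch

from model import create_vit_instance

# Build the 101-class model and its transform on CPU
model, transform = create_vit_instance(num_classes=101, device='cpu')

# Only the new classification head should require gradients
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
assert trainable == ['heads.head.weight', 'heads.head.bias']

# A dummy batch produces one logit per Caltech101 class
model.eval()
with torch.inference_mode():
    logits = model(torch.randn(1, 3, 224, 224))
assert logits.shape == torch.Size([1, 101])
```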
requirements.txt ADDED
@@ -0,0 +1,4 @@
+
+ gradio==5.0.2
+ torch==2.4.0
+ torchvision==0.19.0
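
The pins above can be sanity-checked at runtime; a small sketch (note that torch version strings may carry a local suffix such as "+cpu"):

```python
# Confirm the installed packages match the pins in requirements.txt
import gradio, torch, torchvision

assert gradio.__version__ == "5.0.2"
assert torch.__version__.startswith("2.4.0")
assert torchvision.__version__.startswith("0.19.0")
```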