oschan77 committed on
Commit 5d7161e · 1 Parent(s): 0514458

first commit

.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ vitb16_v1.pth filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,58 @@
+ import model
+ import os
+ import torch
+ import gradio as gr
+ from timeit import default_timer as timer
+ from torch import nn
+ from torchvision import transforms
+
+ class_names = ['chicken', 'elephant', 'sheep']
+
+ DEVICE = 'cpu'
+
+ # Build the ViT-B/16 classifier and its matching preprocessing transform.
+ vit_model, vit_transform = model.create_vitb16_model(
+     num_classes=len(class_names)
+ )
+
+ # Load the fine-tuned weights (tracked via Git LFS) onto the CPU.
+ vit_model.load_state_dict(
+     torch.load(
+         f='vitb16_v1.pth',
+         map_location=torch.device(DEVICE),
+     )
+ )
+
+ def predict_single_image(image):
+     start_time = timer()
+     image = vit_transform(image).unsqueeze(0).to(DEVICE)
+     vit_model.eval()
+     # Run the forward pass inside inference_mode so no autograd state is tracked.
+     with torch.inference_mode():
+         logits = vit_model(image)
+         probs = torch.softmax(logits, dim=1)
+
+     classes_and_probs = {class_names[i]: float(probs[0][i]) for i in range(len(class_names))}
+     inference_time = round(timer() - start_time, 5)
+
+     return classes_and_probs, inference_time
+
+ title = 'AnimalsVision \U0001F413\U0001F418\U0001F411'
+ description = 'A ViT computer vision model to classify images of animals as chicken, elephant or sheep.'
+ article = 'GitHub Repo: https://github.com/oschan77/AnimalsVision-App'
+
+ examples = [['examples/' + example] for example in os.listdir('examples/')]
+
+ app = gr.Interface(
+     fn=predict_single_image,
+     inputs=gr.Image(type='pil'),
+     outputs=[
+         gr.Label(num_top_classes=len(class_names), label='Predictions'),
+         gr.Number(label='Prediction time (sec)'),
+     ],
+     examples=examples,
+     title=title,
+     description=description,
+     article=article,
+ )
+
+ app.launch(
+     share=True,
+ )
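Not part of the commit, but useful for review: a minimal sketch of exercising the same prediction path outside Gradio, assuming the LFS checkpoint has been pulled and one of the bundled example images is on disk.

from PIL import Image
import torch
import model

class_names = ['chicken', 'elephant', 'sheep']
vit_model, vit_transform = model.create_vitb16_model(num_classes=len(class_names))
vit_model.load_state_dict(torch.load('vitb16_v1.pth', map_location='cpu'))
vit_model.eval()

# Preprocess one bundled example and run a single forward pass.
img = Image.open('examples/example_1.jpg').convert('RGB')
with torch.inference_mode():
    probs = torch.softmax(vit_model(vit_transform(img).unsqueeze(0)), dim=1)
print({name: float(p) for name, p in zip(class_names, probs[0])})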
examples/example_1.jpg ADDED
examples/example_2.jpg ADDED
examples/example_3.jpg ADDED
examples/example_4.jpg ADDED
examples/example_5.jpg ADDED
examples/example_6.jpg ADDED
model.py ADDED
@@ -0,0 +1,18 @@
+ import torchvision
+ import torch.nn as nn
+
+ def create_vitb16_model(
+     num_classes: int,
+ ):
+     # Start from ImageNet-pretrained ViT-B/16 weights and their preprocessing transform.
+     vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
+     vit_model = torchvision.models.vit_b_16(weights=vit_weights)
+     vit_transform = vit_weights.transforms()
+
+     # Freeze the backbone so only the new classification head is trainable.
+     for param in vit_model.parameters():
+         param.requires_grad = False
+
+     # Replace the head with a single linear layer sized for the target classes.
+     vit_model.heads = nn.Sequential(
+         nn.Linear(in_features=768, out_features=num_classes, bias=True),
+     )
+
+     return vit_model, vit_transform
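A quick sanity check of this builder (a sketch, not included in the commit): since the backbone is frozen, only the replacement head should be trainable, and the output should have one logit per class.

import torch
import model

vit_model, vit_transform = model.create_vitb16_model(num_classes=3)

# ViT-B/16 with the default weights expects 224x224 inputs.
dummy = torch.randn(1, 3, 224, 224)
with torch.inference_mode():
    out = vit_model(dummy)
print(out.shape)  # torch.Size([1, 3])

# Only the new head should require gradients: 768 * 3 weights + 3 biases = 2307.
trainable = sum(p.numel() for p in vit_model.parameters() if p.requires_grad)
print(trainable)  # 2307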
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch==1.13.1
+ torchvision==0.14.1
+ gradio==3.17.1
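A small sketch (not in the commit) to confirm the pinned versions are what is actually installed in the environment:

import torch, torchvision, gradio
print(torch.__version__, torchvision.__version__, gradio.__version__)
# expect: 1.13.1 0.14.1 3.17.1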
vitb16_v1.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7a9ee27519da16ff37b67fe57efb0383d9633de799f588590355e43abb636241
+ size 343264069
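This entry is a Git LFS pointer rather than the weights themselves; `git lfs pull` materializes the real file in its place. A sketch (not in the commit) of checking the fetched checkpoint against the pointer's size and sha256 oid:

import hashlib
import os

path = 'vitb16_v1.pth'
print(os.path.getsize(path))  # expect 343264069 once LFS has materialized the file

# Hash in 1 MiB chunks to avoid loading the whole 343 MB file into memory.
sha256 = hashlib.sha256()
with open(path, 'rb') as f:
    for chunk in iter(lambda: f.read(1 << 20), b''):
        sha256.update(chunk)
print(sha256.hexdigest())  # expect the oid from the pointer above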