balakrish181 commited on
Commit
d957918
1 Parent(s): d5f841e

first-commit

Browse files
app.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ from model import model_classification
4
+
5
+ path = 'efficient_cat_dog.pth'
6
+ class_names = ['cat','dog']
7
+
8
+
9
+ model,transforms = model_classification()
10
+ model.load_state_dict(torch.load(path))
11
+
12
+
13
+ def predict(img):
14
+
15
+ img = transforms(img).unsqueeze(0)
16
+
17
+ model.eval()
18
+
19
+ with torch.inference_mode():
20
+ logits = model(img)
21
+
22
+ pred_probs = torch.softmax(logits,dim=1)
23
+
24
+ pred_label_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
25
+
26
+
27
+ return pred_label_and_probs
28
+
29
+
30
+ title = 'Cat and Dog classification'
31
+ description = 'An EfficientNetB0 feature extractor computert vision model to classify the cats and dogs'
32
+
33
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
34
+
35
+
36
+ demo = gr.Interface(fn=predict,
37
+ inputs=gr.Image(type='pil'),
38
+ outputs=gr.Label(num_top_classes=2,label='Predictions'),
39
+ title=title,
40
+ examples=example_list,
41
+ description=description,
42
+
43
+ )
44
+
45
+
46
+
47
+ demo.launch(share=True)
efficient_cat_dog.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0779d727b6e18eeb90ea4253776da7e3e4618f8ffc59d4481e128aef27984978
3
+ size 16339558
examples/000110.jpg ADDED
examples/000321.jpg ADDED
examples/000329.jpg ADDED
examples/000380.jpg ADDED
model.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ from torch import nn
4
+ import torchvision
5
+ from torchvision import models
6
+
7
+
8
+ def model_classification():
9
+
10
+ weights = torchvision.models.EfficientNet_B0_Weights.DEFAULT
11
+
12
+ model = models.efficientnet_b0(weights=weights)
13
+ tranforms = models.EfficientNet_B0_Weights.DEFAULT.transforms()
14
+ model.classifier[1] = nn.Linear(1280,2)
15
+
16
+ for params in model.parameters():
17
+ params.requires_grad=False
18
+
19
+
20
+
21
+ return model,tranforms
22
+
23
+
24
+
25
+
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+
2
+ torch==2.2.2
3
+ torchvision==0.17.2
4
+ gradio==4.31.5