OnabajoMonsurat commited on
Commit
39dcf75
1 Parent(s): 712cc3a

Initial commit

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ example/04-pizza-dad.jpg filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Import and class names setup
3
+ import gradio as gr
4
+ import os
5
+ import torch
6
+
7
+ from model import create_effnetb2_model
8
+ from timeit import default_timer as timer
9
+ from typings import Tuple, Dict
10
+ import class_names
11
+
12
+ # Setup class names
13
+ with open(class_names.txt, 'r') as f:
14
+ class_names= [food_name.strip() for food_name in f.readlines()]
15
+
16
+
17
+ # Model and transforms preparation
18
+ effnetb2_model, effnetb2_transform= create_effnetb2_model()
19
+ # Load state dict
20
+ effnetb2_model.load_state_dict(torch.load(
21
+ f= 'effnetb2_food101_model.pth',
22
+ map_location= torch.device('cpu')
23
+ )
24
+ )
25
+
26
+ # Predict function
27
+
28
+ def predict(img)-> Tuple[Dict, float]:
29
+ # start a timer
30
+ start_time= timer()
31
+
32
+ #transform the input image for use with effnet b2
33
+ transform_image= effnetb2_transform(img).unsqueeze(0)
34
+
35
+ #put model into eval mode, make pred
36
+ effnetb2_model.eval()
37
+ with torch.inference_mode():
38
+ pred_logits= effnetb2_model(transform_image)
39
+ pred_prob= torch.softmax(pred_logits, dim=1)
40
+
41
+ # create a pred label and pred prob dict
42
+ pred_label_and_prob= {class_names[i]: float(pred_prob[0][i]) for i in range(len(class_names))}
43
+
44
+
45
+ # calc pred time
46
+ stop_time= timer()
47
+ pred_time= round(stop_time - start_time, 4)
48
+
49
+
50
+ # return pred dict and pred time
51
+ return pred_label_and_prob, pred_time
52
+
53
+ # create example list
54
+ example_list= [['example/'+example] for example in os.listdir('example')]
55
+
56
+ # create gradio app
57
+ title= 'FoodVision Large 🍕🥩🍣 '
58
+ description= 'An EfficientnetB2 feature extractor Computer vision model to classify 101 classes of food from the food 101 image dataset'
59
+ article= 'Created at [09. PyTorch Model Deployment](https://www.learnpytorch.io/09_pytorch_model_deployment/).'
60
+
61
+ # Create the gradio demo
62
+ demo= gr.Interface(fn= predict,
63
+ inputs=gr.Image(type='pil'),
64
+ outputs= [gr.Label(num_top_classes=5, label= 'predictions'),
65
+ gr.Number(label= 'Prediction time (S)')],
66
+ examples= example_list,
67
+ title= title,
68
+ description= description,
69
+ article= article
70
+ )
71
+
72
+ # Launch the demo
73
+ demo.launch()
class_names.txt ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ apple_pie
2
+ baby_back_ribs
3
+ baklava
4
+ beef_carpaccio
5
+ beef_tartare
6
+ beet_salad
7
+ beignets
8
+ bibimbap
9
+ bread_pudding
10
+ breakfast_burrito
11
+ bruschetta
12
+ caesar_salad
13
+ cannoli
14
+ caprese_salad
15
+ carrot_cake
16
+ ceviche
17
+ cheese_plate
18
+ cheesecake
19
+ chicken_curry
20
+ chicken_quesadilla
21
+ chicken_wings
22
+ chocolate_cake
23
+ chocolate_mousse
24
+ churros
25
+ clam_chowder
26
+ club_sandwich
27
+ crab_cakes
28
+ creme_brulee
29
+ croque_madame
30
+ cup_cakes
31
+ deviled_eggs
32
+ donuts
33
+ dumplings
34
+ edamame
35
+ eggs_benedict
36
+ escargots
37
+ falafel
38
+ filet_mignon
39
+ fish_and_chips
40
+ foie_gras
41
+ french_fries
42
+ french_onion_soup
43
+ french_toast
44
+ fried_calamari
45
+ fried_rice
46
+ frozen_yogurt
47
+ garlic_bread
48
+ gnocchi
49
+ greek_salad
50
+ grilled_cheese_sandwich
51
+ grilled_salmon
52
+ guacamole
53
+ gyoza
54
+ hamburger
55
+ hot_and_sour_soup
56
+ hot_dog
57
+ huevos_rancheros
58
+ hummus
59
+ ice_cream
60
+ lasagna
61
+ lobster_bisque
62
+ lobster_roll_sandwich
63
+ macaroni_and_cheese
64
+ macarons
65
+ miso_soup
66
+ mussels
67
+ nachos
68
+ omelette
69
+ onion_rings
70
+ oysters
71
+ pad_thai
72
+ paella
73
+ pancakes
74
+ panna_cotta
75
+ peking_duck
76
+ pho
77
+ pizza
78
+ pork_chop
79
+ poutine
80
+ prime_rib
81
+ pulled_pork_sandwich
82
+ ramen
83
+ ravioli
84
+ red_velvet_cake
85
+ risotto
86
+ samosa
87
+ sashimi
88
+ scallops
89
+ seaweed_salad
90
+ shrimp_and_grits
91
+ spaghetti_bolognese
92
+ spaghetti_carbonara
93
+ spring_rolls
94
+ steak
95
+ strawberry_shortcake
96
+ sushi
97
+ tacos
98
+ takoyaki
99
+ tiramisu
100
+ tuna_tartare
101
+ waffles
effnetb2_food101_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6bc414d33be77a0e4a1e0c50bd6c69f683ea53bdaac5e4ae3abad09e28499ce7
3
+ size 31833771
example/04-pizza-dad.jpg ADDED

Git LFS Details

  • SHA256: 0f00389758009e8430ca17c9a21ebb4564c6945e0c91c58cf058e6a93d267dc8
  • Pointer size: 132 Bytes
  • Size of remote file: 2.87 MB
model.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torchvision
4
+ from torch import nn
5
+
6
+ def create_effnetb2_model(num_classes:int=101,
7
+ seed:int=42):
8
+ # Create Effnet pretrained model
9
+ weights= torchvision.models.EfficientNet_B2_Weights.DEFAULT
10
+ transforms= weights.transforms()
11
+ model= torchvision.models.efficientnet_b2(weights=weights)
12
+
13
+ # Freeze all layers in the base model
14
+ for param in model.parameters():
15
+ param.requires_grad= False
16
+
17
+ # Change the classifier layer
18
+ torch.manual_seed(seed)
19
+ model.classifier= nn.Sequential(
20
+ nn.Dropout(p=0.3, inplace= True),
21
+ nn.Linear(in_features=1408, out_features= num_classes)
22
+ )
23
+
24
+ return model, transforms
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+
2
+ torch==2.0.1
3
+ torchvision==0.15.2
4
+ gradio==3.35.2