alyxx commited on
Commit
732d0e7
1 Parent(s): e53ec2a

adding all files 062023 -kaiku

Browse files
app.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### 1. Imports and class names setup ###
2
+ import gradio as gr
3
+ import os
4
+ import torch
5
+
6
+ from model import create_vit_b_16_swag
7
+ from timeit import default_timer as timer
8
+ from typing import Tuple, Dict
9
+
10
+ # Setup class names
11
+ class_names = ['apple_pie',
12
+ 'baby_back_ribs',
13
+ 'baklava',
14
+ 'beef_carpaccio',
15
+ 'beef_tartare',
16
+ 'beet_salad',
17
+ 'beignets',
18
+ 'bibimbap',
19
+ 'bread_pudding',
20
+ 'breakfast_burrito',
21
+ 'bruschetta',
22
+ 'caesar_salad',
23
+ 'cannoli',
24
+ 'caprese_salad',
25
+ 'carrot_cake',
26
+ 'ceviche',
27
+ 'cheese_plate',
28
+ 'cheesecake',
29
+ 'chicken_curry',
30
+ 'chicken_quesadilla',
31
+ 'chicken_wings',
32
+ 'chocolate_cake',
33
+ 'chocolate_mousse',
34
+ 'churros',
35
+ 'clam_chowder',
36
+ 'club_sandwich',
37
+ 'crab_cakes',
38
+ 'creme_brulee',
39
+ 'croque_madame',
40
+ 'cup_cakes',
41
+ 'deviled_eggs',
42
+ 'donuts',
43
+ 'dumplings',
44
+ 'edamame',
45
+ 'eggs_benedict',
46
+ 'escargots',
47
+ 'falafel',
48
+ 'filet_mignon',
49
+ 'fish_and_chips',
50
+ 'foie_gras',
51
+ 'french_fries',
52
+ 'french_onion_soup',
53
+ 'french_toast',
54
+ 'fried_calamari',
55
+ 'fried_rice',
56
+ 'frozen_yogurt',
57
+ 'garlic_bread',
58
+ 'gnocchi',
59
+ 'greek_salad',
60
+ 'grilled_cheese_sandwich',
61
+ 'grilled_salmon',
62
+ 'guacamole',
63
+ 'gyoza',
64
+ 'hamburger',
65
+ 'hot_and_sour_soup',
66
+ 'hot_dog',
67
+ 'huevos_rancheros',
68
+ 'hummus',
69
+ 'ice_cream',
70
+ 'lasagna',
71
+ 'lobster_bisque',
72
+ 'lobster_roll_sandwich',
73
+ 'macaroni_and_cheese',
74
+ 'macarons',
75
+ 'miso_soup',
76
+ 'mussels',
77
+ 'nachos',
78
+ 'omelette',
79
+ 'onion_rings',
80
+ 'oysters',
81
+ 'pad_thai',
82
+ 'paella',
83
+ 'pancakes',
84
+ 'panna_cotta',
85
+ 'peking_duck',
86
+ 'pho',
87
+ 'pizza',
88
+ 'pork_chop',
89
+ 'poutine',
90
+ 'prime_rib',
91
+ 'pulled_pork_sandwich',
92
+ 'ramen',
93
+ 'ravioli',
94
+ 'red_velvet_cake',
95
+ 'risotto',
96
+ 'samosa',
97
+ 'sashimi',
98
+ 'scallops',
99
+ 'seaweed_salad',
100
+ 'shrimp_and_grits',
101
+ 'spaghetti_bolognese',
102
+ 'spaghetti_carbonara',
103
+ 'spring_rolls',
104
+ 'steak',
105
+ 'strawberry_shortcake',
106
+ 'sushi',
107
+ 'tacos',
108
+ 'takoyaki',
109
+ 'tiramisu',
110
+ 'tuna_tartare',
111
+ 'waffles']
112
+
113
+ ### 2. Model and transforms preparation ###
114
+
115
+ # Create EffNetB0 model
116
+ vit_b_16_swag, vit_b_16_swag_transforms = create_vit_b_16_swag()
117
+
118
+ # Load saved weights
119
+ vit_b_16_swag.load_state_dict(
120
+ torch.load(
121
+ f="vit_b_16_swag_20percent_10epoch.pth",
122
+ map_location=torch.device("cpu"), # load to CPU
123
+ )
124
+ )
125
+
126
+
127
+ ### 3. Predict function ###
128
+
129
+ # Create predict function
130
+ def predict(img) -> Tuple[Dict, float]:
131
+ """Transforms and performs a prediction on img and returns prediction and time taken.
132
+ """
133
+ # Start the timer
134
+ start_time = timer()
135
+
136
+ # Transform the target image and add a batch dimension
137
+ img = vit_b_16_swag_transforms(img).unsqueeze(0)
138
+
139
+ # Put model into evaluation mode and turn on inference mode
140
+ vit_b_16_swag.eval()
141
+ with torch.inference_mode():
142
+ # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
143
+ pred_probs = torch.softmax(vit_b_16_swag(img), dim=1)
144
+
145
+ # Create a prediction label and prediction probability dictionary for each prediction class (this is the required format for Gradio's output parameter)
146
+ pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
147
+
148
+ # Calculate the prediction time
149
+ pred_time = round(timer() - start_time, 5)
150
+
151
+ # Return the prediction dictionary and prediction time
152
+ return pred_labels_and_probs, pred_time
153
+
154
+
155
+ ### 4. Gradio app ###
156
+
157
+ # Create title, description and article strings
158
+ title = "Food Classifier V1"
159
+ description = " 20 Percent Food 101 on Vit_b_16 SWAG"
160
+ article = "Created at google collab. Documentation at https://medium.com/me/stories/public, Code repository at https://github.com/Alyxx-The-Sniper/CNN "
161
+
162
+ # Create examples list from "examples/" directory
163
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
164
+
165
+ # Create the Gradio demo
166
+ demo = gr.Interface(fn=predict, # mapping function from input to output
167
+ inputs=gr.Image(type="pil"), # what are the inputs?
168
+ outputs=[gr.Label(num_top_classes=4, label="Predictions"), # what are the outputs?
169
+ gr.Number(label="Prediction time (s)")],
170
+ # our fn has two outputs, therefore we have two outputs
171
+ # Create examples list from "examples/" directory
172
+ examples=example_list,
173
+ title=title,
174
+ description=description,
175
+ article=article)
176
+
177
+ # Launch the demo!
178
+ demo.launch()
examples/1000431.jpg ADDED
examples/1005066.jpg ADDED
examples/1005649.jpg ADDED
model.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision
2
+ from torch import nn
3
+
4
+ def create_vit_b_16_swag(num_classes:int=101, seed:int=42):
5
+ # 1. Get the base mdoel with pretrained weights and send to target device
6
+ weights = torchvision.models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
7
+ transforms = weights.transforms()
8
+ model = torchvision.models.vit_b_16(weights=weights)#.to(device)
9
+
10
+ # 2. Freeze the base model layers
11
+ for param in model.parameters():
12
+ param.requires_grad = False
13
+
14
+
15
+ # 3. Change the heads
16
+ model.heads = nn.Linear(in_features=768,
17
+ out_features=101)#.to(device)
18
+
19
+ # 5. Give the model a name
20
+ model.name = "vit_b_16_swag"
21
+ print(f"[INFO] Created new {model.name} model.")
22
+ return model, transforms
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch==1.12.0
2
+ torchvision==0.13.0
3
+ gradio==3.1.4
4
+
5
+ output:
6
+ Writing demos/food_classification/requirements.txt
vit_b_16_swag_20percent_10epoch.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa2c4355cf30acf7454a3ee728c981de2e1913b9b1a7bfa2373f8bcbe40c0ecb
3
+ size 344736321