ramirjf committed on
Commit
9d02e76
1 Parent(s): c506e43

initial commit

VIT_32_20_003.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6b9cc08c3955685db4bc561ab448906e0a738e58b6c199a652bea705824cb35f
+ size 343564394
app.py ADDED
@@ -0,0 +1,59 @@
+ ### 1. Imports and class names setup ###
+ import gradio as gr
+ import os
+ import torch
+
+ from model import createVITModel
+ from timeit import default_timer as timer
+ from typing import Tuple, Dict
+
+ # Setup class names
+ with open("classes.txt", "r") as f:  # read the class names in from classes.txt
+     class_names = [food_name.strip() for food_name in f.readlines()]
+
+ ### 2. Model and transforms preparation ###
+ model, vit_transform = createVITModel(out_features=len(class_names))
+ model.load_state_dict(torch.load("VIT_32_20_003.pth", map_location=torch.device("cpu")))
+ model = model.to("cpu")
+
+ ### 3. Predict function ###
+ def predict(img) -> Tuple[Dict, float]:
+     """Transforms and performs a prediction on img and returns the prediction and time taken."""
+     # Start the timer
+     start_time = timer()
+
+     # Transform the target image and add a batch dimension
+     img = vit_transform(img).unsqueeze(dim=0)
+
+     # Put model into evaluation mode and turn on inference mode
+     model.eval()
+     with torch.inference_mode():
+         # Pass the transformed image through the model and turn the prediction logits into prediction probabilities
+         pred_probs = torch.softmax(model(img), dim=1)
+
+     # Create a prediction label and prediction probability dictionary for each class
+     # (this is the required format for Gradio's Label output)
+     pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
+
+     # Calculate the prediction time
+     pred_time = round(timer() - start_time, 5)
+
+     # Return the prediction dictionary and prediction time
+     return pred_labels_and_probs, pred_time
+
+ ### 4. Gradio app ###
+ # Create title, description and article strings
+ title = "Food Image Classifier 🍰 🎂"
+ description = "A ViT feature extractor model that classifies food images into 101 classes."
+ article = ""
+
+ # Create a list of example inputs from the images committed under examples/
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
+
+ # Create the Gradio demo
+ demo = gr.Interface(fn=predict,  # mapping function from input to output
+                     inputs=gr.Image(type="pil"),  # what are the inputs?
+                     outputs=[gr.Label(num_top_classes=5, label="Predictions"),  # what are the outputs?
+                              gr.Number(label="Prediction time (s)")],  # our fn has two outputs, therefore we have two outputs
+                     examples=example_list,
+                     title=title,
+                     description=description,
+                     article=article)
+
+ # Launch the demo!
+ demo.launch(debug=False)  # share=True would generate a publicly shareable URL when running outside Spaces
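
On Spaces the app is served automatically, so the plain demo.launch(debug=False) call above is enough. The share flag mentioned in the closing comment applies when app.py is run on a local machine; a minimal sketch of that alternative launch call, not part of the committed app.py:

    # Hypothetical local launch; unnecessary on Hugging Face Spaces
    demo.launch(debug=True,   # surface errors in the console while developing
                share=True)   # ask Gradio for a temporary public *.gradio.live link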
classes.txt ADDED
@@ -0,0 +1,101 @@
+ apple_pie
+ baby_back_ribs
+ baklava
+ beef_carpaccio
+ beef_tartare
+ beet_salad
+ beignets
+ bibimbap
+ bread_pudding
+ breakfast_burrito
+ bruschetta
+ caesar_salad
+ cannoli
+ caprese_salad
+ carrot_cake
+ ceviche
+ cheese_plate
+ cheesecake
+ chicken_curry
+ chicken_quesadilla
+ chicken_wings
+ chocolate_cake
+ chocolate_mousse
+ churros
+ clam_chowder
+ club_sandwich
+ crab_cakes
+ creme_brulee
+ croque_madame
+ cup_cakes
+ deviled_eggs
+ donuts
+ dumplings
+ edamame
+ eggs_benedict
+ escargots
+ falafel
+ filet_mignon
+ fish_and_chips
+ foie_gras
+ french_fries
+ french_onion_soup
+ french_toast
+ fried_calamari
+ fried_rice
+ frozen_yogurt
+ garlic_bread
+ gnocchi
+ greek_salad
+ grilled_cheese_sandwich
+ grilled_salmon
+ guacamole
+ gyoza
+ hamburger
+ hot_and_sour_soup
+ hot_dog
+ huevos_rancheros
+ hummus
+ ice_cream
+ lasagna
+ lobster_bisque
+ lobster_roll_sandwich
+ macaroni_and_cheese
+ macarons
+ miso_soup
+ mussels
+ nachos
+ omelette
+ onion_rings
+ oysters
+ pad_thai
+ paella
+ pancakes
+ panna_cotta
+ peking_duck
+ pho
+ pizza
+ pork_chop
+ poutine
+ prime_rib
+ pulled_pork_sandwich
+ ramen
+ ravioli
+ red_velvet_cake
+ risotto
+ samosa
+ sashimi
+ scallops
+ seaweed_salad
+ shrimp_and_grits
+ spaghetti_bolognese
+ spaghetti_carbonara
+ spring_rolls
+ steak
+ strawberry_shortcake
+ sushi
+ tacos
+ takoyaki
+ tiramisu
+ tuna_tartare
+ waffles
examples/cookies.jpg ADDED
examples/cupcake.jpg ADDED
examples/flan.jpg ADDED
examples/mochi.jpg ADDED
examples/steak.jpg ADDED
model.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+ import torch.nn as nn
+ import torchvision
+ from typing import Callable, Tuple
+
+ def createVITModel(out_features: int) -> Tuple[nn.Module, Callable]:
+     """Creates a frozen ViT-B/16 feature extractor with a new linear classifier head
+     and returns it together with the transforms that match the pretrained weights."""
+     # 1. Get pretrained weights for ViT-Base
+     pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
+     # 2. Setup a ViT model instance with pretrained weights
+     pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights)
+     # 3. Freeze the base parameters
+     for parameter in pretrained_vit.parameters():
+         parameter.requires_grad = False
+     # 4. Change the classifier head to output one logit per target class
+     pretrained_vit.heads = nn.Linear(in_features=768, out_features=out_features).to('cpu')
+     # 5. Get the transforms used to train the pretrained weights
+     vit_transforms = pretrained_vit_weights.transforms()
+     return pretrained_vit, vit_transforms
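
As a quick sanity check of the factory above (a sketch, not part of the commit), the returned model can be run on a random tensor: the ViT-B/16 transforms produce 224x224 inputs, and the replaced head should yield one logit per class (101, matching classes.txt):

    # Sketch: verify createVITModel's output shape (downloads pretrained ViT weights on first run)
    import torch
    from model import createVITModel

    model, vit_transforms = createVITModel(out_features=101)  # 101 food classes, per classes.txt
    dummy = torch.randn(1, 3, 224, 224)                       # ViT-B/16 expects 224x224 RGB input
    with torch.inference_mode():
        print(model(dummy).shape)                             # expected: torch.Size([1, 101])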
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch==2.1.2
+ torchvision==0.16.2
+ gradio==4.24.0
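
To reproduce the demo outside of Spaces (a sketch; it assumes Git LFS has fetched VIT_32_20_003.pth along with the files above), install the pinned dependencies with "pip install -r requirements.txt" and start the app with "python app.py"; Gradio then serves the interface locally, by default at http://127.0.0.1:7860.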