DiabeticOwl commited on
Commit
cb1d8c8
1 Parent(s): f650d0d

Initial Build.

Browse files
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+
4
+ from model import create_effnetb2_model
5
+ from pathlib import Path
6
+ from timeit import default_timer as timer
7
+ from typing import Tuple, Dict
8
+
9
+ with open('class_names.txt', 'r') as f:
10
+ CLASS_NAMES = [l.strip() for l in f.readlines()]
11
+ NUM_CLASSES = len(CLASS_NAMES)
12
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
13
+ MODEL, MODEL_TRANSFORMS = create_effnetb2_model(NUM_CLASSES, load_st_dict=True)
14
+
15
+
16
+ def predict(img) -> Tuple[Dict, float]:
17
+ start = timer()
18
+ X = MODEL_TRANSFORMS(img).unsqueeze(dim=0).to(DEVICE)
19
+ MODEL.eval()
20
+ with torch.inference_mode():
21
+ y_logits = MODEL(X)
22
+ y_prob = torch.softmax(y_logits, dim=1)
23
+ # Float casting due to Gradio's assumption that numpy objects
24
+ # should be iterated. Running `tolist()` prior to this
25
+ # dictionary comprehension can also fix this behavior.
26
+ y_prob = y_prob.squeeze(dim=0).cpu().numpy()
27
+ pred = {c: float(prob) for c, prob in zip(CLASS_NAMES, y_prob)}
28
+ end = timer()
29
+ return pred, end - start
30
+
31
+
32
+ title = "FoodVision Big 🍖❓"
33
+ description = (
34
+ "An [EfficientNetB2](https://pytorch.org/vision/stable/models/generated/torchvision.models.efficientnet_b2.html) "
35
+ "feature extraction that can classify images of 101 classes of food images."
36
+ )
37
+ article = (
38
+ "Created by [DiabeticOwl](https://huggingface.co/DiabeticOwl). "
39
+ "Uses the [Food101 dataset](https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/)."
40
+ )
41
+ demo = gr.Interface(
42
+ fn=predict,
43
+ inputs=gr.Image(shape=(288, 288), type='pil'),
44
+ outputs=[gr.Label(num_top_classes=5, label='Predictions'),
45
+ gr.Number(label='Prediction run time (s)')],
46
+ examples=list(Path('examples').glob('*.jpg')),
47
+ title=title,
48
+ description=description,
49
+ article=article
50
+ )
51
+
52
+ if __name__ == '__main__':
53
+ demo.launch()
class_names.txt ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ apple_pie
2
+ baby_back_ribs
3
+ baklava
4
+ beef_carpaccio
5
+ beef_tartare
6
+ beet_salad
7
+ beignets
8
+ bibimbap
9
+ bread_pudding
10
+ breakfast_burrito
11
+ bruschetta
12
+ caesar_salad
13
+ cannoli
14
+ caprese_salad
15
+ carrot_cake
16
+ ceviche
17
+ cheese_plate
18
+ cheesecake
19
+ chicken_curry
20
+ chicken_quesadilla
21
+ chicken_wings
22
+ chocolate_cake
23
+ chocolate_mousse
24
+ churros
25
+ clam_chowder
26
+ club_sandwich
27
+ crab_cakes
28
+ creme_brulee
29
+ croque_madame
30
+ cup_cakes
31
+ deviled_eggs
32
+ donuts
33
+ dumplings
34
+ edamame
35
+ eggs_benedict
36
+ escargots
37
+ falafel
38
+ filet_mignon
39
+ fish_and_chips
40
+ foie_gras
41
+ french_fries
42
+ french_onion_soup
43
+ french_toast
44
+ fried_calamari
45
+ fried_rice
46
+ frozen_yogurt
47
+ garlic_bread
48
+ gnocchi
49
+ greek_salad
50
+ grilled_cheese_sandwich
51
+ grilled_salmon
52
+ guacamole
53
+ gyoza
54
+ hamburger
55
+ hot_and_sour_soup
56
+ hot_dog
57
+ huevos_rancheros
58
+ hummus
59
+ ice_cream
60
+ lasagna
61
+ lobster_bisque
62
+ lobster_roll_sandwich
63
+ macaroni_and_cheese
64
+ macarons
65
+ miso_soup
66
+ mussels
67
+ nachos
68
+ omelette
69
+ onion_rings
70
+ oysters
71
+ pad_thai
72
+ paella
73
+ pancakes
74
+ panna_cotta
75
+ peking_duck
76
+ pho
77
+ pizza
78
+ pork_chop
79
+ poutine
80
+ prime_rib
81
+ pulled_pork_sandwich
82
+ ramen
83
+ ravioli
84
+ red_velvet_cake
85
+ risotto
86
+ samosa
87
+ sashimi
88
+ scallops
89
+ seaweed_salad
90
+ shrimp_and_grits
91
+ spaghetti_bolognese
92
+ spaghetti_carbonara
93
+ spring_rolls
94
+ steak
95
+ strawberry_shortcake
96
+ sushi
97
+ tacos
98
+ takoyaki
99
+ tiramisu
100
+ tuna_tartare
101
+ waffles
examples/example_1.jpg ADDED
examples/example_2.jpg ADDED
examples/example_3.jpg ADDED
examples/example_4.jpg ADDED
examples/example_5.jpg ADDED
examples/example_6.jpg ADDED
model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6929a68394c7aef94ef492460a179fb54f62ef3645ebe8f27e6ad464d6080c84
3
+ size 31850487
model.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from pathlib import Path
4
+ from torch import nn
5
+ from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights
6
+ from typing import Optional, Tuple
7
+
8
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
9
+
10
+
11
+ def create_effnetb2_model(
12
+ num_classes: int,
13
+ seed: Optional[int] = 42,
14
+ load_st_dict: Optional[bool] = False
15
+ ) -> Tuple[nn.Module, nn.Module]:
16
+ torch.manual_seed(seed)
17
+ torch.cuda.manual_seed(seed)
18
+ weights = EfficientNet_B2_Weights.DEFAULT
19
+ transforms = weights.transforms()
20
+ model = efficientnet_b2(weights=weights)
21
+
22
+ model.classifier = nn.Sequential(
23
+ nn.Dropout(p=0.3, inplace=True),
24
+ nn.Linear(in_features=1408, out_features=num_classes, bias=True)
25
+ ).to(DEVICE)
26
+ if load_st_dict:
27
+ st_dict = Path('model.pth')
28
+ model.load_state_dict(torch.load(st_dict, map_location=DEVICE))
29
+ for param in model.parameters():
30
+ param.requires_grad = False
31
+
32
+ return model.to(DEVICE), transforms
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch>=2.0.0
2
+ torchvision>=0.15.0
3
+ gradio>=3.33.1