Th3BossC committed on
Commit 15935e0
1 Parent(s): 82051e7

first commit

.gitattributes CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ effnet_b2_food101.pth filter=lfs diff=lfs merge=lfs -text
+ examples/*.jpg filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,46 @@
+ # creating app.py
+
+ ########## imports ############
+ import torch
+ import torch.nn as nn
+ from torchvision import models, transforms
+ import gradio as gr
+ from model import create_model
+ import PIL
+ from PIL import Image
+ import os
+ from pathlib import Path
+ ###############################
+
+ def predict(img):  # Gradio passes the uploaded image in as a PIL image
+     effnetb2, transform = create_model()
+     class_names = []
+     with open('classes.txt', 'r') as f:
+         class_names = [foodname.strip() for foodname in f.readlines()]
+     effnetb2.load_state_dict(torch.load(f = 'effnet_b2_food101.pth', map_location = torch.device('cpu')))
+
+     img = transform(img).unsqueeze(0)
+
+     effnetb2.eval()
+     with torch.inference_mode():
+         pred_label = class_names[effnetb2(img).softmax(dim = 1).argmax(dim = 1).item()]
+         print(pred_label)
+     return pred_label
+
+ ###############################
+ title = 'FoodVision Project'
+ description = 'FoodVision is an image classification model based on EfficientNet_B2, trained on 101 different classes from the Food101 dataset'
+ ###############################
+
+ example_list = [['examples/' + example] for example in os.listdir('examples')]
+
+ demo = gr.Interface(
+     fn = predict,
+     inputs = gr.Image(type = 'pil'),
+     outputs = [gr.Textbox()],
+     examples = example_list,
+     title = title,
+     description = description
+ )
+
+ demo.launch()
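Note on app.py: predict() rebuilds EfficientNet-B2 and re-reads classes.txt on every request, so each Gradio call pays the full model-construction cost. A common refinement is to do that work once at import time and keep only the forward pass inside the handler. A minimal sketch of that variant, reusing the same file names and the create_model() helper committed here (illustrative, not part of this commit):

import os
import torch
import gradio as gr
from model import create_model

# build the model and read the class list once, at import time
effnetb2, transform = create_model()
effnetb2.load_state_dict(torch.load(f = 'effnet_b2_food101.pth', map_location = torch.device('cpu')))
effnetb2.eval()

with open('classes.txt', 'r') as f:
    class_names = [line.strip() for line in f.readlines()]

def predict(img):
    # only preprocessing and the forward pass run per request
    img = transform(img).unsqueeze(0)
    with torch.inference_mode():
        probs = effnetb2(img).softmax(dim = 1)
    return class_names[probs.argmax(dim = 1).item()]

demo = gr.Interface(fn = predict, inputs = gr.Image(type = 'pil'), outputs = [gr.Textbox()])
demo.launch()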
classes.txt ADDED
@@ -0,0 +1,101 @@
+ apple_pie
+ baby_back_ribs
+ baklava
+ beef_carpaccio
+ beef_tartare
+ beet_salad
+ beignets
+ bibimbap
+ bread_pudding
+ breakfast_burrito
+ bruschetta
+ caesar_salad
+ cannoli
+ caprese_salad
+ carrot_cake
+ ceviche
+ cheesecake
+ cheese_plate
+ chicken_curry
+ chicken_quesadilla
+ chicken_wings
+ chocolate_cake
+ chocolate_mousse
+ churros
+ clam_chowder
+ club_sandwich
+ crab_cakes
+ creme_brulee
+ croque_madame
+ cup_cakes
+ deviled_eggs
+ donuts
+ dumplings
+ edamame
+ eggs_benedict
+ escargots
+ falafel
+ filet_mignon
+ fish_and_chips
+ foie_gras
+ french_fries
+ french_onion_soup
+ french_toast
+ fried_calamari
+ fried_rice
+ frozen_yogurt
+ garlic_bread
+ gnocchi
+ greek_salad
+ grilled_cheese_sandwich
+ grilled_salmon
+ guacamole
+ gyoza
+ hamburger
+ hot_and_sour_soup
+ hot_dog
+ huevos_rancheros
+ hummus
+ ice_cream
+ lasagna
+ lobster_bisque
+ lobster_roll_sandwich
+ macaroni_and_cheese
+ macarons
+ miso_soup
+ mussels
+ nachos
+ omelette
+ onion_rings
+ oysters
+ pad_thai
+ paella
+ pancakes
+ panna_cotta
+ peking_duck
+ pho
+ pizza
+ pork_chop
+ poutine
+ prime_rib
+ pulled_pork_sandwich
+ ramen
+ ravioli
+ red_velvet_cake
+ risotto
+ samosa
+ sashimi
+ scallops
+ seaweed_salad
+ shrimp_and_grits
+ spaghetti_bolognese
+ spaghetti_carbonara
+ spring_rolls
+ steak
+ strawberry_shortcake
+ sushi
+ tacos
+ takoyaki
+ tiramisu
+ tuna_tartare
+ waffles
effnet_b2_food101.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:df74811b66ab04f42cffa2661285e967aed5b4458f2342306f0b32c41abed259
+ size 31831029
examples/1987413.jpg ADDED

Git LFS Details

  • SHA256: 90c71e13b04bda3ff7511fd1c7ba2175cfd5ea125ffb0fc9da881c603b639200
  • Pointer size: 130 Bytes
  • Size of remote file: 50.1 kB
examples/2850023.jpg ADDED

Git LFS Details

  • SHA256: c5f45590b083917fcd040de643c65d9caf6413c923de47307cc957c5dcf60324
  • Pointer size: 130 Bytes
  • Size of remote file: 47.6 kB
examples/2955769.jpg ADDED

Git LFS Details

  • SHA256: f12af28fb141d6c69f6f1b7762031a1f648988eef2d5814ec16287fac5a3854b
  • Pointer size: 130 Bytes
  • Size of remote file: 43 kB
examples/3441447.jpg ADDED

Git LFS Details

  • SHA256: e6417bd53e1a76031eed027044062acfcb7066b814faca5014fdb6046eafec4b
  • Pointer size: 130 Bytes
  • Size of remote file: 52.2 kB
examples/551010.jpg ADDED

Git LFS Details

  • SHA256: 79d687070610f0d385445f6f43b964be5c20361e50c7f8147f5a176d0178e269
  • Pointer size: 130 Bytes
  • Size of remote file: 77.4 kB
model.py ADDED
@@ -0,0 +1,25 @@
+ # creating model.py
+
+ ########## imports ############
+ import torch
+ import torch.nn as nn
+ from torchvision import models, transforms
+
+ ###############################
+
+ def create_model():
+     weights = models.EfficientNet_B2_Weights.DEFAULT
+     transform = weights.transforms()  # preprocessing pipeline matching the pretrained weights
+
+     model = models.efficientnet_b2(weights = weights)
+
+     for param in model.parameters():  # freeze the pretrained backbone
+         param.requires_grad = False
+
+     model.classifier = nn.Sequential(  # new head for the 101 Food101 classes
+         nn.Dropout(p = 0.3, inplace = True),
+         nn.Linear(in_features = 1408, out_features = 101, bias = True)
+     )
+
+
+     return model, transform
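Note on model.py: create_model() returns the frozen EfficientNet-B2 with a fresh 101-way head together with its matching transform, so a quick sanity check is to push a dummy batch through it and confirm there is one logit per Food101 class. A minimal sketch (the 288x288 input size and the shape check are illustrative assumptions, not part of this commit):

import torch
from model import create_model

model, transform = create_model()
model.eval()

# a random 3-channel image stands in for a real photo
dummy = torch.randn(1, 3, 288, 288)
with torch.inference_mode():
    logits = model(dummy)

print(logits.shape)  # expected: torch.Size([1, 101])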
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch==1.12.0
+ torchvision==0.13.0
+ gradio==3.1.4