Spaces:
Runtime error
Runtime error
aaronherrera
commited on
Commit
•
06b1c65
1
Parent(s):
7ca342d
Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from transformers import pipeline
|
3 |
+
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
4 |
+
import openpyxl
|
5 |
+
|
6 |
+
#Function to predict the food from the image using the pre-trained model "nateraw/food"
|
7 |
+
def predict(image):
|
8 |
+
extractor = AutoFeatureExtractor.from_pretrained("nateraw/food")
|
9 |
+
model = AutoModelForImageClassification.from_pretrained("nateraw/food")
|
10 |
+
|
11 |
+
input = extractor(images=image, return_tensors='pt')
|
12 |
+
output = model(**input)
|
13 |
+
logits = output.logits
|
14 |
+
|
15 |
+
pred_class = logits.argmax(-1).item()
|
16 |
+
return(model.config.id2label[pred_class])
|
17 |
+
|
18 |
+
#Function to retrieve the Nutritional Value from database.xlsx which is downloaded from USDA
|
19 |
+
def check_food(food, counter):
|
20 |
+
path = './database.xlsx'
|
21 |
+
wb_obj = openpyxl.load_workbook(path)
|
22 |
+
sheet_obj = wb_obj.active
|
23 |
+
|
24 |
+
foodPred, cal, carb, prot, fat = None, None, None, None, None
|
25 |
+
|
26 |
+
#Filter to prioritize the most probable match between the prediction and the entries in the database
|
27 |
+
for i in range(3, sheet_obj.max_row+1):
|
28 |
+
cell_obj = sheet_obj.cell(row = i, column = 2)
|
29 |
+
if counter == 0:
|
30 |
+
if len(food) >= 3:
|
31 |
+
foodName = food[0].capitalize() + " " + food[1] + " " + food[2] + ","
|
32 |
+
elif len(food) == 2:
|
33 |
+
foodName = food[0].capitalize() + " " + food[1] + ","
|
34 |
+
elif len(food) == 1:
|
35 |
+
foodName = food[0].capitalize() + ","
|
36 |
+
condition = foodName == cell_obj.value[0:len(foodName):]
|
37 |
+
elif counter == 1:
|
38 |
+
if len(food) >= 3:
|
39 |
+
foodName = food[0].capitalize() + " " + food[1] + " " + food[2]
|
40 |
+
elif len(food) == 2:
|
41 |
+
foodName = food[0].capitalize() + " " + food[1]
|
42 |
+
elif len(food) == 1:
|
43 |
+
foodName = food[0].capitalize()
|
44 |
+
condition = foodName == cell_obj.value[0:len(foodName):]
|
45 |
+
elif counter == 2:
|
46 |
+
if len(food) >= 3:
|
47 |
+
foodName = food[0] + " " + food[1] + " " + food[2]
|
48 |
+
elif len(food) == 2:
|
49 |
+
foodName = food[0] + " " + food[1]
|
50 |
+
elif len(food) == 1:
|
51 |
+
foodName = food[0]
|
52 |
+
condition = foodName in cell_obj.value
|
53 |
+
elif (counter == 3) & (len(food) > 1):
|
54 |
+
condition = food[0] in cell_obj.value
|
55 |
+
else:
|
56 |
+
break
|
57 |
+
|
58 |
+
#Update values if conditions are met
|
59 |
+
if condition:
|
60 |
+
foodPred = cell_obj.value
|
61 |
+
cal = sheet_obj.cell(row = i, column = 5).value
|
62 |
+
carb = sheet_obj.cell(row = i, column = 7).value
|
63 |
+
prot = sheet_obj.cell(row = i, column = 6).value
|
64 |
+
fat = sheet_obj.cell(row = i, column = 10).value
|
65 |
+
break
|
66 |
+
|
67 |
+
return foodPred, cal, carb, prot, fat
|
68 |
+
|
69 |
+
#Function to prepare the output
|
70 |
+
def get_cc(food, weight):
|
71 |
+
|
72 |
+
#Configure the food string to match the entries in the database
|
73 |
+
food = food.split("_")
|
74 |
+
if food[-1][-1] == "s":
|
75 |
+
food[-1] = food[-1][:-1]
|
76 |
+
|
77 |
+
foodPred, cal, carb, prot, fat = None, None, None, None, None
|
78 |
+
counter = 0
|
79 |
+
|
80 |
+
#Try for the most probable match between the prediction and the entries in the database
|
81 |
+
while (not foodPred) & (counter <= 3):
|
82 |
+
foodPred, cal, carb, prot, fat = check_food(food,counter)
|
83 |
+
counter += 1
|
84 |
+
|
85 |
+
#Check if there is a match
|
86 |
+
if food:
|
87 |
+
output = foodPred + "\nCalories: " + str(round(cal * weight)/100) + " kJ\nCarbohydrate: " + str(round(carb * weight)/100) + " g\nProtein: " + str(round(prot * weight)/100) + " g\nTotal Fat: " + str(round(fat * weight)/100) + " g"
|
88 |
+
elif not food:
|
89 |
+
output = "No data for food"
|
90 |
+
|
91 |
+
return(output)
|
92 |
+
|
93 |
+
#Main function
|
94 |
+
def CC(image, weight):
|
95 |
+
pred = predict(image)
|
96 |
+
cc = get_cc(pred, weight)
|
97 |
+
return(pred, cc)
|
98 |
+
|
99 |
+
interface = gr.Interface(
|
100 |
+
fn = CC,
|
101 |
+
inputs = [gr.inputs.Image(shape=(224,224)), gr.inputs.Number(default = 100, label = "Weight in grams (g):")],
|
102 |
+
outputs = [gr.outputs.Textbox(label='Food Prediction:'), gr.outputs.Textbox(label='Nutritional Value:')],
|
103 |
+
examples = [["pizza.jpg", 107], ["spaghetti.jpg",205]])
|
104 |
+
|
105 |
+
interface.launch()
|