# Calorie_Counter / app.py
# Uploaded by aaronherrera ("Upload app.py", commit be211af).
import functools

import gradio as gr
import openpyxl
from transformers import pipeline
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
@functools.lru_cache(maxsize=1)
def _load_food_classifier():
    """Load the "nateraw/food" feature extractor and model once and memoise them."""
    extractor = AutoFeatureExtractor.from_pretrained("nateraw/food")
    model = AutoModelForImageClassification.from_pretrained("nateraw/food")
    return extractor, model


def predict(image):
    """Predict the food shown in *image* with the pre-trained "nateraw/food" model.

    Args:
        image: the image handed over by the Gradio Image input.

    Returns:
        The label string ``model.config.id2label[argmax(logits)]``.
    """
    # torch is already a hard requirement: the extractor is asked for
    # PyTorch tensors ('pt') below.
    import torch

    # Loading weights is expensive; do it once per process, not per request.
    extractor, model = _load_food_classifier()
    inputs = extractor(images=image, return_tensors='pt')
    # Pure inference: skip autograd bookkeeping to save time and memory.
    with torch.no_grad():
        logits = model(**inputs).logits
    pred_class = logits.argmax(-1).item()
    return model.config.id2label[pred_class]
@functools.lru_cache(maxsize=4)
def _load_food_rows(path):
    """Return (name, col5, col7, col6, col10) values for rows 3.. of the active sheet.

    The workbook is read once per path for the life of the process, so edits
    to the file on disk need a restart to be picked up.
    """
    wb_obj = openpyxl.load_workbook(path)
    try:
        sheet_obj = wb_obj.active
        return tuple(
            (row[1], row[4], row[6], row[5], row[9])
            for row in sheet_obj.iter_rows(min_row=3, max_col=10, values_only=True)
        )
    finally:
        wb_obj.close()


def _food_matcher(food, counter):
    """Return a ``str -> bool`` name test for search stage *counter*, or None if it does not apply."""
    if not food:
        return None
    words = list(food[:3])  # stages 0-2 use at most the first three words
    if counter == 0:
        # "Word1 word2 word3," -- exact leading phrase followed by a comma.
        prefix = " ".join([words[0].capitalize()] + words[1:]) + ","
        return lambda name: name.startswith(prefix)
    if counter == 1:
        # Same leading phrase, comma not required.
        prefix = " ".join([words[0].capitalize()] + words[1:])
        return lambda name: name.startswith(prefix)
    if counter == 2:
        # Phrase (original case) anywhere in the name.
        phrase = " ".join(words)
        return lambda name: phrase in name
    if counter in (3, 4) and len(food) > 1:
        first = food[0]
        if counter == 3:
            # Only the first word, as a capitalised prefix.
            capitalized = first.capitalize()
            return lambda name: name.startswith(capitalized)
        # Only the first word, anywhere in the name.
        return lambda name: first in name
    return None


def check_food(food, counter, path='./database.xlsx'):
    """Look up nutritional values in the spreadsheet *path* (downloaded from USDA).

    Args:
        food: list of words describing the food (get_cc passes the predicted
            label split on "_").
        counter: search stage, 0 (strictest) to 4 (loosest). Stages 3 and 4
            only apply when *food* has more than one word.
        path: workbook to search; defaults to the original './database.xlsx'.

    Returns:
        ``(name, cal, carb, prot, fat)`` read from columns 2, 5, 7, 6 and 10 of
        the first matching row (rows 3 onward), or five ``None`` values when
        the stage does not apply or nothing matches.
    """
    no_match = (None, None, None, None, None)
    matches = _food_matcher(food, counter)
    if matches is None:
        return no_match
    for name, cal, carb, prot, fat in _load_food_rows(path):
        # Blank or numeric name cells cannot match a text search.
        if isinstance(name, str) and matches(name):
            return name, cal, carb, prot, fat
    return no_match
def get_cc(food, weight):
    """Build the nutrition text for a predicted food label.

    Args:
        food: label from ``predict``; words are separated by "_".
        weight: portion weight in grams.

    Returns:
        The matched database name followed by energy, carbohydrate, protein
        and fat, each multiplied by ``weight / 100`` and kept to two decimals.
        Returns "No data for food" when no search stage of ``check_food``
        finds a match.
    """
    words = food.split("_")
    # Crude singularisation ("fries" -> "frie") so plural labels can match
    # database entries; endswith() also tolerates an empty last word.
    if words[-1].endswith("s"):
        words[-1] = words[-1][:-1]

    # Try the strictest search stage first and relax until something matches.
    for counter in range(5):
        food_pred, cal, carb, prot, fat = check_food(words, counter)
        if food_pred:
            break

    # Test the lookup result, not the (always non-empty) word list: an
    # unmatched food must not reach the string concatenation below.
    if not food_pred:
        return "No data for food"

    def scaled(value, unit):
        # round() before dividing by 100 keeps two decimal places.
        # Blank spreadsheet cells are reported instead of crashing.
        if value is None:
            return "N/A"
        return f"{round(value * weight) / 100} {unit}"

    # NOTE(review): the label says "Calories" but the unit printed is kJ --
    # confirm which the database column actually holds.
    return (
        food_pred
        + "\nCalories: " + scaled(cal, "kJ")
        + "\nCarbohydrate: " + scaled(carb, "g")
        + "\nProtein: " + scaled(prot, "g")
        + "\nTotal Fat: " + scaled(fat, "g")
    )
def CC(image, weight):
    """Gradio callback: classify *image*, then describe its nutrition for *weight* grams.

    Returns:
        ``(label, nutrition_text)`` -- one value per output textbox.
    """
    label = predict(image)
    return label, get_cc(label, weight)
# Gradio UI: an image and a portion weight in grams go in; the predicted
# label and its scaled nutritional values come out.
# NOTE(review): gr.inputs / gr.outputs are Gradio's legacy component
# namespaces (deprecated in 3.x, removed in 4.x) -- confirm the pinned
# gradio version before upgrading.
image_input = gr.inputs.Image(shape=(224, 224))
weight_input = gr.inputs.Number(default=100, label="Weight in grams (g):")
label_output = gr.outputs.Textbox(label='Food Prediction:')
nutrition_output = gr.outputs.Textbox(label='Nutritional Value:')

interface = gr.Interface(
    fn=CC,
    inputs=[image_input, weight_input],
    outputs=[label_output, nutrition_output],
    examples=[["pizza.jpg", 107], ["spaghetti.jpg", 205]],
)
interface.launch()