SagarPuniyani commited on
Commit
c981325
·
verified ·
1 Parent(s): 863e2bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -108
app.py CHANGED
@@ -1,109 +1,119 @@
1
- from PIL import Image
2
- from transformers import ViTFeatureExtractor, ViTForImageClassification
3
- import warnings
4
- import requests
5
- import gradio as gr
6
- from dotenv import load_dotenv
7
- import os
8
-
9
- load_dotenv() # Load .env file
10
- # API key for the nutrition information
11
- api_key =os.getenv("API_KEY")
12
-
13
-
14
- warnings.filterwarnings('ignore')
15
-
16
- # Load the pre-trained Vision Transformer model and feature extractor
17
- model_name = "google/vit-base-patch16-224"
18
- feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
19
- model = ViTForImageClassification.from_pretrained(model_name)
20
-
21
-
22
-
23
- def identify_image(image_path):
24
- """Identify the food item in the image."""
25
- image = Image.open(image_path)
26
- inputs = feature_extractor(images=image, return_tensors="pt")
27
- outputs = model(**inputs)
28
- logits = outputs.logits
29
- predicted_class_idx = logits.argmax(-1).item()
30
- predicted_label = model.config.id2label[predicted_class_idx]
31
- food_name = predicted_label.split(',')[0]
32
- return food_name
33
-
34
- def get_calories(food_name):
35
- """Get the calorie information of the identified food item."""
36
- api_url = 'https://api.api-ninjas.com/v1/nutrition?query={}'.format(food_name)
37
- response = requests.get(api_url, headers={'X-Api-Key': api_key})
38
- if response.status_code == requests.codes.ok:
39
- nutrition_info = response.json()
40
- else:
41
- nutrition_info = {"Error": response.status_code, "Message": response.text}
42
- return nutrition_info
43
-
44
- def format_nutrition_info(nutrition_info):
45
- """Format the nutritional information into an HTML table."""
46
- if "Error" in nutrition_info:
47
- return f"Error: {nutrition_info['Error']} - {nutrition_info['Message']}"
48
-
49
- if len(nutrition_info) == 0:
50
- return "No nutritional information found."
51
-
52
- nutrition_data = nutrition_info[0]
53
- table = f"""
54
- <table border="1" style="width: 100%; border-collapse: collapse;">
55
- <tr><th colspan="4" style="text-align: center;"><b>Nutrition Facts</b></th></tr>
56
- <tr><td colspan="4" style="text-align: center;"><b>Food Name: {nutrition_data['name']}</b></td></tr>
57
- <tr>
58
- <td style="text-align: left;"><b>Calories</b></td><td style="text-align: right;">{nutrition_data['calories']}</td>
59
- <td style="text-align: left;"><b>Serving Size (g)</b></td><td style="text-align: right;">{nutrition_data['serving_size_g']}</td>
60
- </tr>
61
- <tr>
62
- <td style="text-align: left;"><b>Total Fat (g)</b></td><td style="text-align: right;">{nutrition_data['fat_total_g']}</td>
63
- <td style="text-align: left;"><b>Saturated Fat (g)</b></td><td style="text-align: right;">{nutrition_data['fat_saturated_g']}</td>
64
- </tr>
65
- <tr>
66
- <td style="text-align: left;"><b>Protein (g)</b></td><td style="text-align: right;">{nutrition_data['protein_g']}</td>
67
- <td style="text-align: left;"><b>Sodium (mg)</b></td><td style="text-align: right;">{nutrition_data['sodium_mg']}</td>
68
- </tr>
69
- <tr>
70
- <td style="text-align: left;"><b>Potassium (mg)</b></td><td style="text-align: right;">{nutrition_data['potassium_mg']}</td>
71
- <td style="text-align: left;"><b>Cholesterol (mg)</b></td><td style="text-align: right;">{nutrition_data['cholesterol_mg']}</td>
72
- </tr>
73
- <tr>
74
- <td style="text-align: left;"><b>Total Carbohydrates (g)</b></td><td style="text-align: right;">{nutrition_data['carbohydrates_total_g']}</td>
75
- <td style="text-align: left;"><b>Fiber (g)</b></td><td style="text-align: right;">{nutrition_data['fiber_g']}</td>
76
- </tr>
77
- <tr>
78
- <td style="text-align: left;"><b>Sugar (g)</b></td><td style="text-align: right;">{nutrition_data['sugar_g']}</td>
79
- <td></td><td></td>
80
- </tr>
81
- </table>
82
- """
83
- return table
84
-
85
- def main_process(image_path):
86
- """Identify the food item and fetch its calorie information."""
87
- food_name = identify_image(image_path)
88
- nutrition_info = get_calories(food_name)
89
- formatted_nutrition_info = format_nutrition_info(nutrition_info)
90
- return formatted_nutrition_info
91
-
92
- # Define the Gradio interface
93
- def gradio_interface(image):
94
- formatted_nutrition_info = main_process(image)
95
- return formatted_nutrition_info
96
-
97
- # Create the Gradio UI
98
- iface = gr.Interface(
99
- fn=gradio_interface,
100
- inputs=gr.Image(type="filepath"),
101
- outputs="html",
102
- title="Food Identification and Nutrition Info",
103
- description="Upload an image of food to get nutritional information.",
104
- allow_flagging="never" # Disable flagging
105
- )
106
-
107
- # Launch the Gradio app
108
- if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
109
  iface.launch(share=True)
 
1
+ from PIL import Image
2
+ from transformers import ViTFeatureExtractor, ViTForImageClassification
3
+ import warnings
4
+ import requests
5
+ import gradio as gr
6
+ from dotenv import load_dotenv
7
+ import os
8
+
9
+ load_dotenv() # Load .env file
10
+ # API key for the nutrition information
11
+ api_key =os.getenv("API_KEY")
12
+
13
+
14
+ warnings.filterwarnings('ignore')
15
+
16
+ # Load the pre-trained Vision Transformer model and feature extractor
17
+ model_name = "google/vit-base-patch16-224"
18
+ feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
19
+ model = ViTForImageClassification.from_pretrained(model_name)
20
+
21
+
22
+ examples = [
23
+ ["data/apple.jpg"], # Path to example images
24
+ ["data/banana.jpg"],
25
+ ["data/pizza.png"],
26
+ ["data/burger.png"],
27
+ ["data/grapes.jpeg"],
28
+ ["data/fries.jpeg"],
29
+ ]
30
+
31
+
32
+ def identify_image(image_path):
33
+ """Identify the food item in the image."""
34
+ image = Image.open(image_path)
35
+ inputs = feature_extractor(images=image, return_tensors="pt")
36
+ outputs = model(**inputs)
37
+ logits = outputs.logits
38
+ predicted_class_idx = logits.argmax(-1).item()
39
+ predicted_label = model.config.id2label[predicted_class_idx]
40
+ food_name = predicted_label.split(',')[0]
41
+ return food_name
42
+
43
+ def get_calories(food_name):
44
+ """Get the calorie information of the identified food item."""
45
+ api_url = 'https://api.api-ninjas.com/v1/nutrition?query={}'.format(food_name)
46
+ response = requests.get(api_url, headers={'X-Api-Key': api_key})
47
+ if response.status_code == requests.codes.ok:
48
+ nutrition_info = response.json()
49
+ else:
50
+ nutrition_info = {"Error": response.status_code, "Message": response.text}
51
+ return nutrition_info
52
+
53
+ def format_nutrition_info(nutrition_info):
54
+ """Format the nutritional information into an HTML table."""
55
+ if "Error" in nutrition_info:
56
+ return f"Error: {nutrition_info['Error']} - {nutrition_info['Message']}"
57
+
58
+ if len(nutrition_info) == 0:
59
+ return "No nutritional information found."
60
+
61
+ nutrition_data = nutrition_info[0]
62
+ table = f"""
63
+ <table border="1" style="width: 100%; border-collapse: collapse;">
64
+ <tr><th colspan="4" style="text-align: center;"><b>Nutrition Facts</b></th></tr>
65
+ <tr><td colspan="4" style="text-align: center;"><b>Food Name: {nutrition_data['name']}</b></td></tr>
66
+ <tr>
67
+ <td style="text-align: left;"><b>Calories</b></td><td style="text-align: right;">{nutrition_data['calories']}</td>
68
+ <td style="text-align: left;"><b>Serving Size (g)</b></td><td style="text-align: right;">{nutrition_data['serving_size_g']}</td>
69
+ </tr>
70
+ <tr>
71
+ <td style="text-align: left;"><b>Total Fat (g)</b></td><td style="text-align: right;">{nutrition_data['fat_total_g']}</td>
72
+ <td style="text-align: left;"><b>Saturated Fat (g)</b></td><td style="text-align: right;">{nutrition_data['fat_saturated_g']}</td>
73
+ </tr>
74
+ <tr>
75
+ <td style="text-align: left;"><b>Protein (g)</b></td><td style="text-align: right;">{nutrition_data['protein_g']}</td>
76
+ <td style="text-align: left;"><b>Sodium (mg)</b></td><td style="text-align: right;">{nutrition_data['sodium_mg']}</td>
77
+ </tr>
78
+ <tr>
79
+ <td style="text-align: left;"><b>Potassium (mg)</b></td><td style="text-align: right;">{nutrition_data['potassium_mg']}</td>
80
+ <td style="text-align: left;"><b>Cholesterol (mg)</b></td><td style="text-align: right;">{nutrition_data['cholesterol_mg']}</td>
81
+ </tr>
82
+ <tr>
83
+ <td style="text-align: left;"><b>Total Carbohydrates (g)</b></td><td style="text-align: right;">{nutrition_data['carbohydrates_total_g']}</td>
84
+ <td style="text-align: left;"><b>Fiber (g)</b></td><td style="text-align: right;">{nutrition_data['fiber_g']}</td>
85
+ </tr>
86
+ <tr>
87
+ <td style="text-align: left;"><b>Sugar (g)</b></td><td style="text-align: right;">{nutrition_data['sugar_g']}</td>
88
+ <td></td><td></td>
89
+ </tr>
90
+ </table>
91
+ """
92
+ return table
93
+
94
+ def main_process(image_path):
95
+ """Identify the food item and fetch its calorie information."""
96
+ food_name = identify_image(image_path)
97
+ nutrition_info = get_calories(food_name)
98
+ formatted_nutrition_info = format_nutrition_info(nutrition_info)
99
+ return formatted_nutrition_info
100
+
101
+ # Define the Gradio interface
102
+ def gradio_interface(image):
103
+ formatted_nutrition_info = main_process(image)
104
+ return formatted_nutrition_info
105
+
106
+ # Create the Gradio UI
107
+ iface = gr.Interface(
108
+ fn=gradio_interface,
109
+ inputs=gr.Image(type="filepath"),
110
+ outputs="html",
111
+ title="Food Identification and Nutrition Info",
112
+ description="Upload an image of food to get nutritional information.",
113
+ examples=examples, # Add examples
114
+ allow_flagging="never" # Disable flagging
115
+ )
116
+
117
+ # Launch the Gradio app
118
+ if __name__ == "__main__":
119
  iface.launch(share=True)