versatile153 commited on
Commit
deed510
·
verified ·
1 Parent(s): 157b42e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -62
app.py CHANGED
@@ -3,15 +3,14 @@ import base64
3
  from io import BytesIO
4
  from PIL import Image
5
  from fastapi import FastAPI, HTTPException
6
- import requests
7
  from transformers import pipeline
8
  from ultralytics import YOLO
9
  import uvicorn
10
- import pydantic
11
  import gradio as gr
12
  import threading
13
  import logging
14
- import numpy as np
15
 
16
  # ==============================
17
  # Logging
@@ -25,27 +24,19 @@ logger = logging.getLogger(__name__)
25
  food_classifier = pipeline("image-classification", model="nateraw/food")
26
  yolo_model = YOLO("yolov8n.pt")
27
 
28
- # ==============================
29
- # USDA API configuration
30
- # ==============================
31
  USDA_API_URL = "https://api.nal.usda.gov/fdc/v1/foods/search"
32
  USDA_API_KEY = os.getenv("USDA_API_KEY", "qktfia6caeuBSww2A5SYns8NaLlE2OuozHaEASzw")
33
 
34
- # ==============================
35
  # FastAPI app
36
- # ==============================
37
  app = FastAPI()
38
 
39
- # ==============================
40
- # Pydantic model for request body
41
- # ==============================
42
- class ImageRequest(pydantic.BaseModel):
43
- image: str # Base64-encoded image
44
- portion_size: float = 100.0 # Default portion size in grams
45
 
46
- # ==============================
47
- # Helper: decode base64 image
48
- # ==============================
49
  def decode_base64_image(base64_string):
50
  try:
51
  img_data = base64.b64decode(base64_string)
@@ -55,50 +46,37 @@ def decode_base64_image(base64_string):
55
  logger.error(f"Image decoding failed: {str(e)}")
56
  raise HTTPException(status_code=400, detail="Invalid base64 image")
57
 
58
- # ==============================
59
- # Helper: crop image around detected food/container
60
- # ==============================
61
- def crop_image_to_food(img, yolo_results, food_labels=["chicken_curry", "pizza", "salad", "lasagna", "risotto"], container_labels=["bowl", "plate", "dish"]):
62
  try:
63
- # Prioritize food labels first
64
  for result in yolo_results:
65
  for box, cls in zip(result.boxes.xyxy, result.boxes.cls):
66
  label = result.names[int(cls)]
67
  if label in food_labels:
68
  x1, y1, x2, y2 = map(int, box)
69
- cropped_img = img.crop((x1, y1, x2, y2))
70
- logger.info(f"Cropped image to food {label} at coordinates: ({x1}, {y1}, {x2}, {y2})")
71
- return cropped_img, True
72
- # Fallback to container labels
73
  for result in yolo_results:
74
  for box, cls in zip(result.boxes.xyxy, result.boxes.cls):
75
  label = result.names[int(cls)]
76
  if label in container_labels:
77
  x1, y1, x2, y2 = map(int, box)
78
- cropped_img = img.crop((x1, y1, x2, y2))
79
- logger.info(f"Cropped image to container {label} at coordinates: ({x1}, {y1}, {x2}, {y2})")
80
- return cropped_img, True
81
- logger.info("No food or container detected for cropping")
82
  return img, False
83
  except Exception as e:
84
  logger.error(f"Cropping failed: {str(e)}")
85
  return img, False
86
 
87
- # ==============================
88
- # Helper: nutrient calculation with USDA API
89
- # ==============================
90
  def calculate_nutrients(food_items, portion_size):
91
  nutrients = {"protein": 0, "carbs": 0, "fat": 0, "fiber": 0, "sodium": 0}
92
  micronutrients = {"vitamin_c": 0, "calcium": 0, "iron": 0}
93
  top_food = max(food_items, key=food_items.get, default=None)
94
  if not top_food:
95
- logger.warning("No food items detected for nutrient calculation")
96
  return nutrients, micronutrients, 0
97
 
98
- # Replace underscores with spaces for USDA API query
99
  query_food = top_food.replace("_", " ")
100
- logger.info(f"Querying USDA API for: {query_food}")
101
-
102
  try:
103
  response = requests.get(USDA_API_URL, params={
104
  "api_key": USDA_API_KEY,
@@ -107,10 +85,7 @@ def calculate_nutrients(food_items, portion_size):
107
  })
108
  response.raise_for_status()
109
  data = response.json()
110
- logger.debug(f"USDA API response: {data}")
111
-
112
  if not data.get("foods"):
113
- logger.warning(f"No food data found for {query_food}")
114
  return nutrients, micronutrients, 0
115
 
116
  food_data = data["foods"][0]
@@ -129,45 +104,41 @@ def calculate_nutrients(food_items, portion_size):
129
  "iron": food_nutrients.get("Iron, Fe", 0) * (portion_size / 100),
130
  }
131
  calories = (nutrients["protein"] * 4) + (nutrients["carbs"] * 4) + (nutrients["fat"] * 9)
132
- logger.info(f"Nutrients calculated: {nutrients}, Calories: {calories}")
133
  return nutrients, micronutrients, calories
134
- except requests.exceptions.RequestException as e:
135
  logger.error(f"USDA API request failed: {str(e)}")
136
  return nutrients, micronutrients, 0
137
 
138
  # ==============================
139
- # FastAPI endpoint for food analysis
140
  # ==============================
141
  @app.post("/analyze_food")
142
  async def analyze_food(request: ImageRequest):
143
  try:
144
- # Decode image
145
  img = decode_base64_image(request.image)
146
-
147
- # Run YOLO to detect objects and crop to food or container
148
  yolo_results = yolo_model(img)
149
  cropped_img, was_cropped = crop_image_to_food(img, yolo_results)
150
 
151
- # Food classification on cropped image
152
  food_results = food_classifier(cropped_img)
153
  food_items = {r["label"]: r["score"] for r in food_results if r["score"] >= 0.3}
154
- logger.info(f"Food items detected: {food_items}")
155
 
156
- # Non-food detection (YOLO objects not in food_items)
157
- non_food_items = [r.names[int(cls)] for r in yolo_results for cls in r.boxes.cls if r.names[int(cls)] not in food_items]
158
-
159
- # FIX: Decide non-food properly
160
- if any(score > 0.3 for score in food_items.values()):
161
- is_non_food = False
162
- else:
163
- is_non_food = True
 
 
 
164
 
165
- logger.info(f"Non-food items: {non_food_items}, is_non_food: {is_non_food}")
166
 
167
- # Nutrient analysis
168
  nutrients, micronutrients, calories = calculate_nutrients(food_items, request.portion_size)
169
 
170
- # Simplified ingredient inference
171
  ingredient_map = {
172
  "pizza": ["dough", "tomato sauce", "cheese"],
173
  "salad": ["lettuce", "tomato", "cucumber"],
@@ -221,9 +192,6 @@ iface = gr.Interface(
221
  description="Upload an image to analyze food items, non-food items, and nutritional content."
222
  )
223
 
224
- # ==============================
225
- # Run both FastAPI + Gradio
226
- # ==============================
227
  if __name__ == "__main__":
228
  threading.Thread(target=lambda: uvicorn.run(app, host="0.0.0.0", port=8000)).start()
229
  iface.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
3
  from io import BytesIO
4
  from PIL import Image
5
  from fastapi import FastAPI, HTTPException
6
+ from pydantic import BaseModel
7
  from transformers import pipeline
8
  from ultralytics import YOLO
9
  import uvicorn
 
10
  import gradio as gr
11
  import threading
12
  import logging
13
+ import requests
14
 
15
  # ==============================
16
  # Logging
 
24
  food_classifier = pipeline("image-classification", model="nateraw/food")
25
  yolo_model = YOLO("yolov8n.pt")
26
 
27
+ # USDA API config
 
 
28
  USDA_API_URL = "https://api.nal.usda.gov/fdc/v1/foods/search"
29
  USDA_API_KEY = os.getenv("USDA_API_KEY", "qktfia6caeuBSww2A5SYns8NaLlE2OuozHaEASzw")
30
 
 
31
  # FastAPI app
 
32
  app = FastAPI()
33
 
34
+ # Request schema
35
+ class ImageRequest(BaseModel):
36
+ image: str # Base64 image
37
+ portion_size: float = 100.0
 
 
38
 
39
+ # Decode base64 image
 
 
40
  def decode_base64_image(base64_string):
41
  try:
42
  img_data = base64.b64decode(base64_string)
 
46
  logger.error(f"Image decoding failed: {str(e)}")
47
  raise HTTPException(status_code=400, detail="Invalid base64 image")
48
 
49
+ # Crop image to food or container
50
+ def crop_image_to_food(img, yolo_results,
51
+ food_labels=["chicken_curry", "pizza", "salad", "lasagna", "risotto"],
52
+ container_labels=["bowl", "plate", "dish"]):
53
  try:
 
54
  for result in yolo_results:
55
  for box, cls in zip(result.boxes.xyxy, result.boxes.cls):
56
  label = result.names[int(cls)]
57
  if label in food_labels:
58
  x1, y1, x2, y2 = map(int, box)
59
+ return img.crop((x1, y1, x2, y2)), True
 
 
 
60
  for result in yolo_results:
61
  for box, cls in zip(result.boxes.xyxy, result.boxes.cls):
62
  label = result.names[int(cls)]
63
  if label in container_labels:
64
  x1, y1, x2, y2 = map(int, box)
65
+ return img.crop((x1, y1, x2, y2)), True
 
 
 
66
  return img, False
67
  except Exception as e:
68
  logger.error(f"Cropping failed: {str(e)}")
69
  return img, False
70
 
71
+ # Calculate nutrients
 
 
72
  def calculate_nutrients(food_items, portion_size):
73
  nutrients = {"protein": 0, "carbs": 0, "fat": 0, "fiber": 0, "sodium": 0}
74
  micronutrients = {"vitamin_c": 0, "calcium": 0, "iron": 0}
75
  top_food = max(food_items, key=food_items.get, default=None)
76
  if not top_food:
 
77
  return nutrients, micronutrients, 0
78
 
 
79
  query_food = top_food.replace("_", " ")
 
 
80
  try:
81
  response = requests.get(USDA_API_URL, params={
82
  "api_key": USDA_API_KEY,
 
85
  })
86
  response.raise_for_status()
87
  data = response.json()
 
 
88
  if not data.get("foods"):
 
89
  return nutrients, micronutrients, 0
90
 
91
  food_data = data["foods"][0]
 
104
  "iron": food_nutrients.get("Iron, Fe", 0) * (portion_size / 100),
105
  }
106
  calories = (nutrients["protein"] * 4) + (nutrients["carbs"] * 4) + (nutrients["fat"] * 9)
 
107
  return nutrients, micronutrients, calories
108
+ except Exception as e:
109
  logger.error(f"USDA API request failed: {str(e)}")
110
  return nutrients, micronutrients, 0
111
 
112
  # ==============================
113
+ # FastAPI endpoint
114
  # ==============================
115
  @app.post("/analyze_food")
116
  async def analyze_food(request: ImageRequest):
117
  try:
 
118
  img = decode_base64_image(request.image)
 
 
119
  yolo_results = yolo_model(img)
120
  cropped_img, was_cropped = crop_image_to_food(img, yolo_results)
121
 
122
+ # Food classification
123
  food_results = food_classifier(cropped_img)
124
  food_items = {r["label"]: r["score"] for r in food_results if r["score"] >= 0.3}
 
125
 
126
+ # Fix: whitelist food labels so they don’t go into non_food_items
127
+ food_label_whitelist = [
128
+ "pizza", "salad", "chicken", "chicken_wings", "shrimp_and_grits",
129
+ "lasagna", "risotto", "burger", "sandwich", "pasta"
130
+ ]
131
+ non_food_items = [
132
+ r.names[int(cls)]
133
+ for r in yolo_results
134
+ for cls in r.boxes.cls
135
+ if r.names[int(cls)] not in food_items and r.names[int(cls)] not in food_label_whitelist
136
+ ]
137
 
138
+ is_non_food = len(non_food_items) > len(food_items) and max(food_items.values(), default=0) < 0.5
139
 
 
140
  nutrients, micronutrients, calories = calculate_nutrients(food_items, request.portion_size)
141
 
 
142
  ingredient_map = {
143
  "pizza": ["dough", "tomato sauce", "cheese"],
144
  "salad": ["lettuce", "tomato", "cucumber"],
 
192
  description="Upload an image to analyze food items, non-food items, and nutritional content."
193
  )
194
 
 
 
 
195
  if __name__ == "__main__":
196
  threading.Thread(target=lambda: uvicorn.run(app, host="0.0.0.0", port=8000)).start()
197
  iface.launch(server_name="0.0.0.0", server_port=7860, share=True)