import sys
sys.path.append('.')  # Force local imports

import base64
import cv2
import gradio as gr
import numpy as np
import torch

from models.experimental import attempt_load
from utils.datasets import letterbox
from utils.general import non_max_suppression, scale_coords
from gradio_webrtc import WebRTC
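# models/ and utils/ are YOLOv7 repository modules; they are assumed to sit
# next to app.py, which is why sys.path is extended above.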
# Load the YOLOv7 model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = attempt_load('/home/user/app/best.pt', map_location=device)  # Load trained weights
model.eval()

# Class names from the YOLO model (index -> label)
class_names = model.names
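# Note: the keys in nutritional_data below must match these class names exactly,
# since detections are looked up by name when the nutrition overlay is drawn.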
# Dummy nutritional data (modify with actual data)
nutritional_data = {
    "biryani": {"calories": 290, "protein": 7.0, "fat": 9.0, "carbs": 40.0},
    "buttermilk": {"calories": 40, "protein": 3.0, "fat": 1.0, "carbs": 4.0},
    "chole": {"calories": 150, "protein": 8.0, "fat": 6.0, "carbs": 20.0},
    "chutney": {"calories": 50, "protein": 1.0, "fat": 3.0, "carbs": 4.0},
    "curd": {"calories": 98, "protein": 11.0, "fat": 4.0, "carbs": 3.0},
    "dal": {"calories": 120, "protein": 9.0, "fat": 3.0, "carbs": 15.0},
    "dosa": {"calories": 133, "protein": 2.7, "fat": 3.7, "carbs": 22.0},
    "dry-curry": {"calories": 180, "protein": 7.0, "fat": 8.0, "carbs": 18.0},
    "egg": {"calories": 68, "protein": 6.0, "fat": 5.0, "carbs": 0.5},
    "gulab-jamun": {"calories": 143, "protein": 2.0, "fat": 7.0, "carbs": 20.0},
    "idli": {"calories": 58, "protein": 2.0, "fat": 0.4, "carbs": 12.0},
    "omelette": {"calories": 154, "protein": 11.0, "fat": 11.0, "carbs": 1.0},
    "pickle": {"calories": 30, "protein": 0.5, "fat": 1.5, "carbs": 4.0},
    "poori": {"calories": 150, "protein": 3.0, "fat": 7.0, "carbs": 20.0},
    "rice": {"calories": 130, "protein": 2.7, "fat": 0.3, "carbs": 28.0},
    "roti": {"calories": 71, "protein": 2.7, "fat": 0.4, "carbs": 15.0},
    "salad": {"calories": 35, "protein": 1.0, "fat": 1.0, "carbs": 6.0},
    "sambhar": {"calories": 50, "protein": 2.5, "fat": 1.5, "carbs": 7.0},
    "soft-drink": {"calories": 150, "protein": 0.0, "fat": 0.0, "carbs": 39.0},
    "tandoori-chicken": {"calories": 270, "protein": 27.0, "fat": 15.0, "carbs": 3.0},
    "vada": {"calories": 132, "protein": 2.6, "fat": 5.5, "carbs": 16.0},
    "wet-curry": {"calories": 190, "protein": 8.0, "fat": 10.0, "carbs": 15.0}
}
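# A detected class that is missing from this table still gets a bounding box,
# just without a nutrition label.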
# Process an image and overlay nutrition info on each detection
def detect_and_overlay_nutrition(image, conf_threshold=0.25, iou_threshold=0.45):
    img = np.array(image)
    img_resized = letterbox(img, new_shape=640)[0]  # Letterbox resize to 640, keeping aspect ratio
    img_resized = img_resized[:, :, ::-1].transpose(2, 0, 1)  # Reverse channel order and HWC -> CHW
    img_resized = np.ascontiguousarray(img_resized)

    # Convert to tensor
    img_tensor = torch.from_numpy(img_resized).to(device).float()
    img_tensor /= 255.0  # Normalize to [0, 1]
    img_tensor = img_tensor.unsqueeze(0)  # Add batch dimension

    with torch.no_grad():
        pred = model(img_tensor, augment=False)[0]
    pred = non_max_suppression(pred, conf_threshold, iou_threshold, agnostic=False)

    original_img = np.array(image)
    for det in pred:
        if len(det):
            # Rescale boxes from the letterboxed size back to the original image
            det[:, :4] = scale_coords(img_tensor.shape[2:], det[:, :4], original_img.shape).round()
            for *xyxy, conf, cls in det:
                cls_name = class_names[int(cls)]
                nutrition = nutritional_data.get(cls_name, None)

                # Draw bounding box
                cv2.rectangle(original_img, (int(xyxy[0]), int(xyxy[1])),
                              (int(xyxy[2]), int(xyxy[3])), (255, 0, 0), 2)

                if nutrition:
                    label = (f"{cls_name}: {conf:.2f}\n"
                             f"Calories: {nutrition['calories']} kcal\n"
                             f"Protein: {nutrition['protein']} g\n"
                             f"Carbs: {nutrition['carbs']} g\n"
                             f"Fat: {nutrition['fat']} g")
                    y_offset = int(xyxy[1]) - 10
                    padding = 10
                    # Render the multi-line label on a white background, one line at a time
                    for line in label.split('\n'):
                        text_width, text_height = cv2.getTextSize(line, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0]
                        cv2.rectangle(original_img, (int(xyxy[0]), y_offset - padding),
                                      (int(xyxy[0]) + text_width, y_offset + text_height + padding),
                                      (255, 255, 255), -1)
                        cv2.putText(original_img, line, (int(xyxy[0]), y_offset + text_height),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)
                        y_offset += text_height + padding
    return original_img
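# detect_and_overlay_nutrition is shared by the upload tab (PIL image input) and
# the live-stream callback below (numpy frame input); np.array() accepts either.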
# Function for WebRTC live detection
def detect_live(frame, conf_threshold=0.25, iou_threshold=0.45):
    try:
        if isinstance(frame, str):
            # Decode a base64-encoded frame into an image
            nparr = np.frombuffer(base64.b64decode(frame), np.uint8)
            img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        else:
            # Frame already arrives as a numpy array
            img = frame
        processed_img = detect_and_overlay_nutrition(img, conf_threshold, iou_threshold)
        return processed_img
    except Exception as e:
        print(f"Error processing frame: {e}")
        return None
rtc_configuration = {
    "iceServers": [
        {"urls": ["stun:stun.l.google.com:19302"]},  # Free STUN server
        {
            "urls": "turn:global.turn.twilio.com:3478?transport=udp",  # Twilio TURN server (UDP)
            "username": "your_twilio_username",      # Replace with your Twilio username
            "credential": "your_twilio_credential"   # Replace with your Twilio credential
        },
        {
            "urls": "turn:global.turn.twilio.com:3478?transport=tcp",  # Twilio TURN server (TCP)
            "username": "your_twilio_username",      # Replace with your Twilio username
            "credential": "your_twilio_credential"   # Replace with your Twilio credential
        }
    ]
}
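# The STUN server lets the browser discover its public address; the TURN entries
# relay media when a direct peer connection is blocked, which is common on hosted
# platforms. TURN only works once the Twilio placeholders above are replaced with
# real credentials.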
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# 🍲 Real-Time Food Detection (Upload & Live)")

    with gr.Tab("📸 Upload Image"):
        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload Food Image")
            conf_threshold = gr.Slider(0.0, 1.0, value=0.25, step=0.05, label="Confidence Threshold")
            iou_threshold = gr.Slider(0.0, 1.0, value=0.45, step=0.05, label="IoU Threshold")
        image_output = gr.Image(type="pil", label="Detected Food with Nutrition")
        image_button = gr.Button("Detect")
        image_button.click(detect_and_overlay_nutrition,
                           inputs=[image_input, conf_threshold, iou_threshold],
                           outputs=image_output)

    with gr.Tab("🎥 Live Webcam Detection"):
        with gr.Row():
            live_stream = WebRTC(label="Live Webcam", rtc_configuration=rtc_configuration)
            conf_slider = gr.Slider(0.0, 1.0, value=0.25, step=0.05, label="Confidence Threshold")
            iou_slider = gr.Slider(0.0, 1.0, value=0.45, step=0.05, label="IoU Threshold")
        live_stream.stream(fn=detect_live,
                           inputs=[live_stream, conf_slider, iou_slider],
                           outputs=[live_stream],
                           time_limit=10)

if __name__ == "__main__":
    demo.launch()
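# Locally, `python app.py` serves the demo on Gradio's default port (7860);
# on Hugging Face Spaces the Space runtime starts the app automatically.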