|
import sys |
|
sys.path.append('.') |
|
import gradio as gr |
|
import torch |
|
import cv2 |
|
import numpy as np |
|
from models.experimental import attempt_load |
|
from utils.datasets import letterbox |
|
from utils.general import non_max_suppression, scale_coords |
|
from gradio_webrtc import WebRTC |
|
|
|
|
|
|
|
|
|
|
|
|
|
# Prefer the GPU when CUDA is usable; otherwise run inference on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Load the detection weights onto the selected device.
model = attempt_load('/home/user/app/best.pt', map_location=device)
# Switch the model to evaluation mode before serving predictions.
model.eval()

# Class labels exposed by the loaded model; indexed by predicted class id
# when detections are turned into overlay text.
class_names = model.names
|
|
|
|
|
# Nutrition facts keyed by detector class name. The overlay renders
# "calories" with a kcal suffix and the three macros with a g suffix.
# NOTE(review): the serving basis (per 100 g vs. per portion) is not stated
# anywhere in this file - confirm against the data source.
nutritional_data = {
    food: {"calories": kcal, "protein": protein, "fat": fat, "carbs": carbs}
    for food, kcal, protein, fat, carbs in (
        # name,             kcal, protein, fat, carbs
        ("biryani", 290, 7.0, 9.0, 40.0),
        ("buttermilk", 40, 3.0, 1.0, 4.0),
        ("chole", 150, 8.0, 6.0, 20.0),
        ("chutney", 50, 1.0, 3.0, 4.0),
        ("curd", 98, 11.0, 4.0, 3.0),
        ("dal", 120, 9.0, 3.0, 15.0),
        ("dosa", 133, 2.7, 3.7, 22.0),
        ("dry-curry", 180, 7.0, 8.0, 18.0),
        ("egg", 68, 6.0, 5.0, 0.5),
        ("gulab-jamun", 143, 2.0, 7.0, 20.0),
        ("idli", 58, 2.0, 0.4, 12.0),
        ("omelette", 154, 11.0, 11.0, 1.0),
        ("pickle", 30, 0.5, 1.5, 4.0),
        ("poori", 150, 3.0, 7.0, 20.0),
        ("rice", 130, 2.7, 0.3, 28.0),
        ("roti", 71, 2.7, 0.4, 15.0),
        ("salad", 35, 1.0, 1.0, 6.0),
        ("sambhar", 50, 2.5, 1.5, 7.0),
        ("soft-drink", 150, 0.0, 0.0, 39.0),
        ("tandoori-chicken", 270, 27.0, 15.0, 3.0),
        ("vada", 132, 2.6, 5.5, 16.0),
        ("wet-curry", 190, 8.0, 10.0, 15.0),
    )
}
|
|
|
|
|
def _draw_nutrition_label(canvas, lines, x, y_top, padding=10):
    """Draw ``lines`` as stacked black-on-white text rows starting at (``x``, ``y_top``)."""
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    y_offset = y_top
    for line in lines:
        text_width, text_height = cv2.getTextSize(line, font, font_scale, 1)[0]
        # Filled white rectangle behind the text so it stays readable.
        cv2.rectangle(canvas, (x, y_offset - padding),
                      (x + text_width, y_offset + text_height + padding),
                      (255, 255, 255), -1)
        cv2.putText(canvas, line, (x, y_offset + text_height),
                    font, font_scale, (0, 0, 0), 2)
        y_offset += text_height + padding


def detect_and_overlay_nutrition(image, conf_threshold=0.25, iou_threshold=0.45):
    """Detect food items in ``image`` and draw boxes plus nutrition labels.

    Args:
        image: Anything ``np.array`` can turn into an H x W x C pixel array
            (the upload tab passes a PIL image, the live path a numpy array),
            or ``None`` when no image was supplied.
        conf_threshold: Confidence threshold forwarded to
            ``non_max_suppression``.
        iou_threshold: IoU threshold forwarded to ``non_max_suppression``.

    Returns:
        A numpy array copy of the input with a rectangle drawn around every
        detection and, for classes present in ``nutritional_data``, a
        multi-line nutrition label. Returns ``None`` when ``image`` is
        ``None`` instead of failing inside the preprocessing code.
    """
    if image is None:
        # Nothing to process (e.g. "Detect" clicked before uploading).
        return None

    img = np.array(image)
    # Draw on a separate copy so preprocessing never sees the annotations.
    original_img = img.copy()

    # Resize/pad to the 640 network input, reverse the channel axis and move
    # channels first (HWC -> CHW).
    # NOTE(review): the ``[:, :, ::-1]`` flip assumes a particular input
    # channel order; PIL uploads and cv2-decoded frames differ in that
    # respect - confirm which one the model was trained on.
    img_resized = letterbox(img, new_shape=640)[0]
    img_resized = img_resized[:, :, ::-1].transpose(2, 0, 1)
    img_resized = np.ascontiguousarray(img_resized)

    # Scale pixel values to [0, 1] and add the batch dimension.
    img_tensor = torch.from_numpy(img_resized).to(device).float()
    img_tensor /= 255.0
    img_tensor = img_tensor.unsqueeze(0)

    with torch.no_grad():
        pred = model(img_tensor, augment=False)[0]
        pred = non_max_suppression(pred, conf_threshold, iou_threshold, agnostic=False)

    for det in pred:
        if not len(det):
            continue

        # Map box coordinates from the network input size back onto the
        # original image.
        det[:, :4] = scale_coords(img_tensor.shape[2:], det[:, :4], original_img.shape).round()

        for *xyxy, conf, cls in det:
            cls_name = class_names[int(cls)]
            nutrition = nutritional_data.get(cls_name, None)

            x1, y1, x2, y2 = (int(coord) for coord in xyxy)
            cv2.rectangle(original_img, (x1, y1), (x2, y2), (255, 0, 0), 2)

            # Classes without a nutrition entry only get the bounding box.
            if nutrition:
                label_lines = [
                    f"{cls_name}: {conf:.2f}",
                    f"Calories: {nutrition['calories']} kcal",
                    f"Protein: {nutrition['protein']} g",
                    f"Carbs: {nutrition['carbs']} g",
                    f"Fat: {nutrition['fat']} g",
                ]
                _draw_nutrition_label(original_img, label_lines, x1, y1 - 10)

    return original_img
|
|
|
|
|
def detect_live(frame, conf_threshold=0.25, iou_threshold=0.45):
    """Run food detection on a single live frame and return the annotated image.

    Args:
        frame: Either an already-decoded image as a numpy array, or a
            base64-encoded image (``str`` or bytes-like). A string may carry
            a ``data:...;base64,`` header, which is stripped before decoding.
        conf_threshold: Confidence threshold forwarded to
            ``detect_and_overlay_nutrition``.
        iou_threshold: IoU threshold forwarded to
            ``detect_and_overlay_nutrition``.

    Returns:
        The result of ``detect_and_overlay_nutrition`` for the frame, or
        ``None`` when the frame is missing, cannot be decoded, or processing
        raises. Failures are printed rather than propagated, so one bad frame
        does not raise out of the stream callback.
    """
    # Imported locally: the file's import block does not provide ``base64``,
    # so the previous code raised NameError on every frame and always
    # returned None.
    import base64

    if frame is None:
        return None

    try:
        if isinstance(frame, np.ndarray):
            # NOTE(review): the WebRTC component is expected to hand over
            # decoded frames as numpy arrays - confirm against the installed
            # gradio_webrtc version.
            img = frame
        else:
            if isinstance(frame, str) and frame.startswith("data:"):
                # Drop the "data:<mime>;base64," header of a data URL.
                frame = frame.partition(",")[2]
            nparr = np.frombuffer(base64.b64decode(frame), np.uint8)
            img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
            if img is None:
                print("Error processing frame: could not decode image data")
                return None

        return detect_and_overlay_nutrition(img, conf_threshold, iou_threshold)
    except Exception as e:
        # Best-effort boundary: report and skip the frame.
        print(f"Error processing frame: {e}")
        return None
|
|
|
# ICE servers handed to the WebRTC component: one public STUN server plus a
# TURN relay entry per transport.
# NOTE(review): the TURN username/credential below are placeholders and must
# be replaced with real values for the relay to work.
rtc_configuration = {
    "iceServers": [
        {"urls": ["stun:stun.l.google.com:19302"]},
        *[
            {
                "urls": turn_url,
                "username": "your_twilio_username",
                "credential": "your_twilio_credential",
            }
            for turn_url in (
                "turn:global.turn.twilio.com:3478?transport=udp",
                "turn:global.turn.twilio.com:3478?transport=tcp",
            )
        ],
    ]
}
|
|
|
|
|
|
|
|
|
# Gradio UI: one tab for single-image uploads and one for the live webcam
# stream.
# NOTE(review): indentation was not recoverable from the source, so the
# Row/Tab nesting below is reconstructed - confirm the intended layout.
with gr.Blocks() as demo:
    # The emoji in the three UI strings were mis-decoded (UTF-8 bytes read in
    # a Greek single-byte code page); they are restored here.
    gr.Markdown("# 🍲 Real-Time Food Detection (Upload & Live)")

    with gr.Tab("📸 Upload Image"):
        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload Food Image")
            conf_threshold = gr.Slider(0.0, 1.0, value=0.25, step=0.05, label="Confidence Threshold")
            iou_threshold = gr.Slider(0.0, 1.0, value=0.45, step=0.05, label="IoU Threshold")
        image_output = gr.Image(type="pil", label="Detected Food with Nutrition")
        image_button = gr.Button("Detect")
        # Run detection once per click on the uploaded image.
        image_button.click(
            detect_and_overlay_nutrition,
            inputs=[image_input, conf_threshold, iou_threshold],
            outputs=image_output,
        )

    with gr.Tab("🎥 Live Webcam Detection"):
        with gr.Row():
            live_stream = WebRTC(label="Live Webcam", rtc_configuration=rtc_configuration)
            conf_slider = gr.Slider(0.0, 1.0, value=0.25, step=0.05, label="Confidence Threshold")
            iou_slider = gr.Slider(0.0, 1.0, value=0.45, step=0.05, label="IoU Threshold")
        # The WebRTC component is both the frame source and the sink for the
        # annotated frames returned by detect_live.
        live_stream.stream(
            fn=detect_live,
            inputs=[live_stream, conf_slider, iou_slider],
            outputs=[live_stream],
            time_limit=10,
        )


# Start the web app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|