Spaces:
Sleeping
Sleeping
| import sys | |
| sys.path.append('.') # Force local imports | |
| import gradio as gr | |
| import torch | |
| import cv2 | |
| import numpy as np | |
| from models.experimental import attempt_load | |
| from utils.datasets import letterbox | |
| from utils.general import non_max_suppression, scale_coords | |
| from gradio_webrtc import WebRTC | |
# Run on the GPU when CUDA is usable; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Load the YOLOv7 checkpoint onto the chosen device and put it in
# evaluation mode, since this app only runs inference.
model = attempt_load('/home/user/app/best.pt', map_location=device)
model.eval()

# Class labels as stored on the loaded model; indexed by predicted class id.
class_names = model.names
# Placeholder nutrition facts keyed by detector class name -- replace with
# real figures before relying on them. The overlay code prints calories as
# "kcal" and protein/fat/carbs as "g".
# NOTE(review): the serving size these numbers refer to is not stated anywhere.
nutritional_data = {
    "biryani": dict(calories=290, protein=7.0, fat=9.0, carbs=40.0),
    "buttermilk": dict(calories=40, protein=3.0, fat=1.0, carbs=4.0),
    "chole": dict(calories=150, protein=8.0, fat=6.0, carbs=20.0),
    "chutney": dict(calories=50, protein=1.0, fat=3.0, carbs=4.0),
    "curd": dict(calories=98, protein=11.0, fat=4.0, carbs=3.0),
    "dal": dict(calories=120, protein=9.0, fat=3.0, carbs=15.0),
    "dosa": dict(calories=133, protein=2.7, fat=3.7, carbs=22.0),
    "dry-curry": dict(calories=180, protein=7.0, fat=8.0, carbs=18.0),
    "egg": dict(calories=68, protein=6.0, fat=5.0, carbs=0.5),
    "gulab-jamun": dict(calories=143, protein=2.0, fat=7.0, carbs=20.0),
    "idli": dict(calories=58, protein=2.0, fat=0.4, carbs=12.0),
    "omelette": dict(calories=154, protein=11.0, fat=11.0, carbs=1.0),
    "pickle": dict(calories=30, protein=0.5, fat=1.5, carbs=4.0),
    "poori": dict(calories=150, protein=3.0, fat=7.0, carbs=20.0),
    "rice": dict(calories=130, protein=2.7, fat=0.3, carbs=28.0),
    "roti": dict(calories=71, protein=2.7, fat=0.4, carbs=15.0),
    "salad": dict(calories=35, protein=1.0, fat=1.0, carbs=6.0),
    "sambhar": dict(calories=50, protein=2.5, fat=1.5, carbs=7.0),
    "soft-drink": dict(calories=150, protein=0.0, fat=0.0, carbs=39.0),
    "tandoori-chicken": dict(calories=270, protein=27.0, fat=15.0, carbs=3.0),
    "vada": dict(calories=132, protein=2.6, fat=5.5, carbs=16.0),
    "wet-curry": dict(calories=190, protein=8.0, fat=10.0, carbs=15.0),
}
def _draw_nutrition_label(canvas, cls_name, conf, nutrition, x_left, y_top):
    """Paint the stacked nutrition text for one detection onto ``canvas`` in place.

    Each line gets a filled white background rectangle and black text. The
    first line starts 10 px above ``y_top`` and later lines stack downwards.
    """
    label_lines = (
        f"{cls_name}: {conf:.2f}",
        f"Calories: {nutrition['calories']} kcal",
        f"Protein: {nutrition['protein']} g",
        f"Carbs: {nutrition['carbs']} g",
        f"Fat: {nutrition['fat']} g",
    )
    padding = 10
    y_offset = y_top - 10
    for line in label_lines:
        (text_width, text_height), _baseline = cv2.getTextSize(
            line, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        cv2.rectangle(canvas, (x_left, y_offset - padding),
                      (x_left + text_width, y_offset + text_height + padding),
                      (255, 255, 255), -1)
        cv2.putText(canvas, line, (x_left, y_offset + text_height),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)
        y_offset += text_height + padding


def detect_and_overlay_nutrition(image, conf_threshold=0.25, iou_threshold=0.45):
    """Detect food items in ``image`` and draw boxes plus nutrition labels.

    Args:
        image: A PIL image or a NumPy array. NumPy arrays are treated as
            OpenCV-style BGR and flipped to RGB before inference; any other
            input (e.g. a PIL image from the upload tab) is taken to be RGB
            already and is passed through unflipped.
        conf_threshold: Confidence threshold forwarded to
            ``non_max_suppression``.
        iou_threshold: IoU threshold forwarded to ``non_max_suppression``.

    Returns:
        A NumPy array copy of the input, in the input's own channel order,
        with a rectangle drawn around every detection and a nutrition label
        for classes present in ``nutritional_data``. Returns ``None`` when
        ``image`` is ``None`` (e.g. "Detect" clicked with nothing uploaded).
    """
    if image is None:
        return None

    # Decide the channel order before converting, while the input type is
    # still known. Previously the flip was unconditional, so RGB uploads
    # reached the model as BGR.
    needs_bgr_to_rgb = isinstance(image, np.ndarray)

    # np.array() copies, so drawing below never touches the caller's image.
    original_img = np.array(image)

    # Letterbox towards a 640 px target; the exact output shape is decided
    # by `letterbox` itself.
    model_input = letterbox(original_img, new_shape=640)[0]
    if needs_bgr_to_rgb:
        model_input = model_input[:, :, ::-1]
    # HWC -> CHW, made contiguous so torch.from_numpy accepts it.
    model_input = np.ascontiguousarray(model_input.transpose(2, 0, 1))

    img_tensor = torch.from_numpy(model_input).to(device).float()
    img_tensor /= 255.0  # scale pixel values to [0, 1]
    img_tensor = img_tensor.unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        pred = model(img_tensor, augment=False)[0]
    pred = non_max_suppression(pred, conf_threshold, iou_threshold, agnostic=False)

    for det in pred:
        if not len(det):
            continue
        # Map boxes from the letterboxed tensor back to the original image.
        det[:, :4] = scale_coords(
            img_tensor.shape[2:], det[:, :4], original_img.shape).round()
        for *xyxy, conf, cls in det:
            x1, y1, x2, y2 = (int(v) for v in xyxy)
            cls_name = class_names[int(cls)]
            cv2.rectangle(original_img, (x1, y1), (x2, y2), (255, 0, 0), 2)
            nutrition = nutritional_data.get(cls_name, None)
            if nutrition:
                _draw_nutrition_label(original_img, cls_name, conf, nutrition, x1, y1)
    return original_img
def detect_live(frame, conf_threshold=0.25, iou_threshold=0.45):
    """Run detection on one live-stream frame.

    Args:
        frame: Either a NumPy image array (used as-is) or a base64-encoded
            image as ``str``/``bytes``. A ``data:...;base64,`` prefix on a
            string payload is stripped before decoding.
            NOTE(review): gradio_webrtc is expected to deliver ndarray
            frames -- confirm against the installed version.
        conf_threshold: Confidence threshold passed to the detector.
        iou_threshold: IoU threshold passed to the detector.

    Returns:
        The annotated frame from ``detect_and_overlay_nutrition``, or
        ``None`` when there is no frame or processing fails (the error is
        printed rather than raised).
    """
    if frame is None:
        return None
    try:
        if isinstance(frame, np.ndarray):
            img = frame
        else:
            # Function-scope import: the file's import block does not bring
            # in base64, so the old code always failed with NameError here.
            import base64

            payload = frame
            if isinstance(payload, str) and payload.startswith("data:"):
                payload = payload.partition(",")[2]
            nparr = np.frombuffer(base64.b64decode(payload), np.uint8)
            img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
            if img is None:
                # imdecode signals failure with None rather than raising.
                raise ValueError("frame bytes could not be decoded as an image")
        return detect_and_overlay_nutrition(img, conf_threshold, iou_threshold)
    except Exception as e:
        # Deliberate best-effort: report the bad frame and carry on.
        print(f"Error processing frame: {e}")
        return None
# Placeholder Twilio TURN credentials shared by the UDP and TCP relay entries.
# Replace both values with real Twilio credentials.
_TWILIO_TURN_AUTH = {
    "username": "your_twilio_username",
    "credential": "your_twilio_credential",
}

# ICE server list given to the WebRTC component: one free public STUN server
# plus the Twilio TURN relay over UDP and over TCP.
rtc_configuration = {
    "iceServers": [
        {"urls": ["stun:stun.l.google.com:19302"]},
        {"urls": "turn:global.turn.twilio.com:3478?transport=udp", **_TWILIO_TURN_AUTH},
        {"urls": "turn:global.turn.twilio.com:3478?transport=tcp", **_TWILIO_TURN_AUTH},
    ]
}
# Gradio interface: one tab for still-image upload, one for live webcam
# detection over WebRTC.
# NOTE(review): the original indentation was lost, so which widgets sit
# inside each gr.Row() is reconstructed here -- confirm the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("# 🍲 Real-Time Food Detection (Upload & Live)")

    with gr.Tab("📸 Upload Image"):
        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload Food Image")
            conf_threshold = gr.Slider(0.0, 1.0, value=0.25, step=0.05, label="Confidence Threshold")
            iou_threshold = gr.Slider(0.0, 1.0, value=0.45, step=0.05, label="IoU Threshold")
        image_output = gr.Image(type="pil", label="Detected Food with Nutrition")
        image_button = gr.Button("Detect")
        # Run detection on the uploaded image with the slider thresholds.
        image_button.click(
            detect_and_overlay_nutrition,
            inputs=[image_input, conf_threshold, iou_threshold],
            outputs=image_output,
        )

    with gr.Tab("🎥 Live Webcam Detection"):
        with gr.Row():
            live_stream = WebRTC(label="Live Webcam", rtc_configuration=rtc_configuration)
            conf_slider = gr.Slider(0.0, 1.0, value=0.25, step=0.05, label="Confidence Threshold")
            iou_slider = gr.Slider(0.0, 1.0, value=0.45, step=0.05, label="IoU Threshold")
        # The same component is both input and output, so processed frames
        # are shown back in the webcam widget.
        live_stream.stream(
            fn=detect_live,
            inputs=[live_stream, conf_slider, iou_slider],
            outputs=[live_stream],
            time_limit=10,
        )

if __name__ == "__main__":
    demo.launch()