Spaces:
Sleeping
Sleeping
"""Gradio demo app for Food-101 classification.""" | |
import sys | |
from pathlib import Path | |
from typing import Tuple, Dict, List | |
import time | |
import tempfile | |
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
# Expose the repository's ``scripts`` directory on the import path so the
# project-local ``predict`` and ``train`` modules imported below resolve.
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root.joinpath("scripts")))
from predict import Food101Predictor | |
from train import load_food101_splits | |
class GradioFood101App:
    """Gradio application for Food-101 classification.

    Wraps a ``Food101Predictor`` (ONNX model) and exposes it through a
    ``gr.Interface`` with an image input and a top-k slider.
    """

    def __init__(self):
        """Initialize the Gradio app with the ONNX predictor.

        Raises:
            Exception: anything raised by ``load_model`` propagates, so a
                missing model or dataset aborts start-up.
        """
        self.predictor = None
        self.load_model()

    def load_model(self):
        """Load the ONNX predictor together with the Food-101 class names.

        The label order given to the predictor is the index order of the
        ``idx_to_class`` mapping returned by ``load_food101_splits``.

        Raises:
            Exception: any failure is printed with an ``[ERROR]`` prefix and
                re-raised.
        """
        try:
            model_path = project_root / "models/efficientnet_b0_food101.onnx"
            data_dir = project_root / "food-101/food-101"

            # NOTE(review): the dataset splits are loaded only to recover the
            # label names, so the dataset directory must exist at serve time.
            # Confirm that is intended for deployment.
            _, _, _, idx_to_class = load_food101_splits(data_dir, val_split=0.1, seed=42)
            class_names = [idx_to_class[i] for i in range(len(idx_to_class))]

            self.predictor = Food101Predictor(model_path, class_names)
            print(f"[GRADIO] Model loaded successfully with {len(class_names)} classes")
        except Exception as e:
            print(f"[ERROR] Failed to load model: {e}")
            raise

    @staticmethod
    def _format_class_name(name: str) -> str:
        """Turn a dataset label such as ``apple_pie`` into ``Apple Pie``."""
        return name.replace('_', ' ').title()

    def predict_image(self, image: "Image.Image", top_k: int = 5) -> Tuple[Dict, str]:
        """Predict the food class of an uploaded image.

        Args:
            image: PIL image from the Gradio ``Image`` input, or ``None`` when
                nothing has been uploaded.
            top_k: Number of top predictions to request. It is coerced to
                ``int`` because a Gradio slider may deliver a float.

        Returns:
            ``(confidences, info_text)``. ``confidences`` maps readable class
            names to probabilities (the format ``gr.Label`` consumes), and
            ``info_text`` is a Markdown summary. On failure the dict is empty
            and ``info_text`` carries the message. Exceptions are not
            propagated to Gradio.
        """
        if image is None:
            return {}, "Please upload an image first!"
        if self.predictor is None:
            return {}, "Model not loaded. Please try again."

        # Bound before the try so the cleanup in `finally` never hits an
        # unassigned name, even if creating the temp file itself fails.
        temp_path = None
        try:
            # Only reserve a unique path here and close the handle before
            # writing: reopening a still-open NamedTemporaryFile by name
            # fails on Windows.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file:
                temp_path = Path(tmp_file.name)

            # JPEG cannot store alpha/palette modes (e.g. RGBA, P); saving
            # them as .jpg would raise.
            if image.mode not in ("RGB", "L"):
                image = image.convert("RGB")
            image.save(temp_path, format="JPEG")

            start_time = time.perf_counter()
            predictions, probabilities, inference_time = self.predictor.predict(
                temp_path, int(top_k)
            )
            total_time = (time.perf_counter() - start_time) * 1000

            confidences = {
                self._format_class_name(pred): float(prob)
                for pred, prob in zip(predictions, probabilities)
            }

            # NOTE(review): the leading glyphs in these strings look like
            # mis-encoded emoji (UTF-8 bytes decoded with a Greek code page).
            # Restore the intended characters.
            top_label = self._format_class_name(predictions[0])
            info_lines = [
                "π **Prediction Results**",
                f"β‘ **Inference Time**: {inference_time:.2f}ms",
                f"π **Total Time**: {total_time:.2f}ms",
                "π§ **Model**: EfficientNet-B0 (ONNX)",
                f"π **Top Prediction**: {top_label} ({probabilities[0]*100:.1f}%)",
            ]
            return confidences, "\n".join(info_lines)
        except Exception as e:
            # UI boundary: show the failure in the details panel rather than
            # surfacing a stack trace to the user.
            return {}, f"β **Error**: {e}"
        finally:
            if temp_path is not None:
                temp_path.unlink(missing_ok=True)

    def get_examples(self) -> List[List]:
        """Return demo example rows in the form ``[image_path, top_k]``.

        Collects ``*.jpg`` files from ``food-101/food-101/images/examples``
        under the project root. Returns an empty list when the directory is
        missing or contains no JPEGs.
        """
        examples_dir = project_root / "food-101/food-101/images/examples"
        if not examples_dir.is_dir():
            return []
        # Sort so the example gallery order does not depend on the filesystem.
        return [[str(image_path), 5] for image_path in sorted(examples_dir.glob("*.jpg"))]

    def create_interface(self) -> "gr.Interface":
        """Create and return the Gradio interface (not yet launched)."""
        # Custom CSS for better styling
        css = """
        .main-header {
            text-align: center;
            background: linear-gradient(90deg, #ff6b6b, #4ecdc4);
            -webkit-background-clip: text;
            -webkit-text-fill-color: transparent;
            font-size: 2.5em;
            font-weight: bold;
            margin-bottom: 20px;
        }
        .info-box {
            background-color: #f0f8ff;
            border-left: 5px solid #4ecdc4;
            padding: 15px;
            margin: 10px 0;
            border-radius: 5px;
        }
        """

        # Markdown shown under the title.
        description = """
        ## π Food-101 Image Classifier
        Upload an image of food and get AI-powered predictions! This demo uses a fine-tuned **EfficientNet-B0** model
        trained on the Food-101 dataset to classify 101 different types of food.
        ### π― **Model Performance**
        - **Accuracy**: 84.49% on test set
        - **Inference Speed**: ~7ms per image
        - **Classes**: 101 different food types
        ### π **How to use**
        1. Upload an image or try one of our examples
        2. Adjust the number of top predictions (1-10)
        3. Click Submit to get predictions with confidence scores!
        """

        interface = gr.Interface(
            fn=self.predict_image,
            inputs=[
                gr.Image(
                    type="pil",
                    label="πΈ Upload Food Image",
                    height=300,
                ),
                gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    label="π’ Number of Predictions",
                ),
            ],
            outputs=[
                # num_top_classes matches the slider maximum so every
                # requested prediction can be displayed.
                gr.Label(
                    label="π Predictions & Confidence Scores",
                    num_top_classes=10,
                ),
                gr.Markdown(
                    label="π Prediction Details",
                ),
            ],
            title="π Food-101 AI Classifier",
            description=description,
            # Pass None rather than an empty list when no examples exist.
            examples=self.get_examples() or None,
            css=css,
            theme=gr.themes.Soft(),
            flagging_mode="never",
        )
        return interface
def main():
    """Build the Food-101 demo and serve it on 0.0.0.0:7860.

    Any exception raised while constructing or launching the app is printed
    with an ``[ERROR]`` prefix and then re-raised.
    """
    try:
        print("[GRADIO] Initializing Food-101 Classifier App...")
        food_app = GradioFood101App()

        print("[GRADIO] Creating Gradio interface...")
        demo = food_app.create_interface()

        print("[GRADIO] Launching app...")
        launch_options = {
            "share": False,  # set to True to create a public link
            "server_name": "0.0.0.0",  # listen on all network interfaces
            "server_port": 7860,
            "show_error": True,
        }
        demo.launch(**launch_options)
    except Exception as exc:
        print(f"[ERROR] Failed to launch Gradio app: {exc}")
        raise


if __name__ == "__main__":
    main()