LeooNic committed on
Commit b29fbac · Parent: b6c74eb

Deploy Food-101 classifier with 84.49% accuracy


- EfficientNet-B0 model achieving 84.49% test accuracy
- ONNX optimized for 7ms inference time
- Interactive Gradio interface with example images
- Supports 101 food classes from Food-101 dataset
- Ready for production deployment

README.md CHANGED
@@ -1,14 +1,122 @@
1
  ---
2
- title: Food 101 Classifier
3
- emoji: 👁
4
- colorFrom: red
5
  colorTo: green
6
  sdk: gradio
7
- sdk_version: 5.46.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
- short_description: My Space
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: Food-101 AI Classifier
3
+ emoji: 🍔
4
+ colorFrom: blue
5
  colorTo: green
6
  sdk: gradio
7
+ sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
11
  ---
12
 
13
+ # 🍔 Food-101 AI Classifier
14
+
15
+ An AI-powered food image classifier trained on the Food-101 dataset, capable of recognizing 101 different types of food with high accuracy.
16
+
17
+ ## 🎯 Model Performance
18
+
19
+ - **Architecture**: EfficientNet-B0 (fine-tuned)
20
+ - **Test Accuracy**: 84.49%
21
+ - **Top-5 Accuracy**: 96.72%
22
+ - **Inference Speed**: ~7ms per image
23
+ - **Model Size**: 15.77 MB (ONNX optimized)
24
+
25
+ ## 🚀 Features
26
+
27
+ - **Fast Inference**: ONNX-optimized model with ~7 ms per-image predictions
28
+ - **High Accuracy**: 84.49% top-1 / 96.72% top-5 accuracy on the Food-101 test set
29
+ - **User-Friendly Interface**: Clean and intuitive Gradio web interface
30
+ - **Real-time Predictions**: Upload any food image and get instant results
31
+ - **Confidence Scores**: See how confident the model is about each prediction
32
+
33
+ ## 📊 Dataset
34
+
35
+ This model was trained on the [Food-101 dataset](https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/), which contains:
36
+ - **101 food categories**
37
+ - **1,000 images per category**
38
+ - **Total**: 101,000 images
39
+
40
+ ## 🏆 Recognized Food Categories
41
+
42
+ The model can classify the following 101 food types:
43
+
44
+ ```
45
+ apple_pie, baby_back_ribs, baklava, beef_carpaccio, beef_tartare, beet_salad,
46
+ beignets, bibimbap, bread_pudding, breakfast_burrito, bruschetta, caesar_salad,
47
+ cannoli, caprese_salad, carrot_cake, ceviche, cheese_plate, cheesecake,
48
+ chicken_curry, chicken_quesadilla, chicken_wings, chocolate_cake, chocolate_mousse,
49
+ churros, clam_chowder, club_sandwich, crab_cakes, creme_brulee, croque_madame,
50
+ cup_cakes, deviled_eggs, donuts, dumplings, edamame, eggs_benedict, escargots,
51
+ falafel, filet_mignon, fish_and_chips, foie_gras, french_fries, french_onion_soup,
52
+ french_toast, fried_calamari, fried_rice, frozen_yogurt, garlic_bread, gnocchi,
53
+ greek_salad, grilled_cheese_sandwich, grilled_salmon, guacamole, gyoza, hamburger,
54
+ hot_and_sour_soup, hot_dog, huevos_rancheros, hummus, ice_cream, lasagna,
55
+ lobster_bisque, lobster_roll_sandwich, macaroni_and_cheese, macarons, miso_soup,
56
+ mussels, nachos, omelette, onion_rings, oysters, pad_thai, paella, pancakes,
57
+ panna_cotta, peking_duck, pho, pizza, pork_chop, poutine, prime_rib, pulled_pork_sandwich,
58
+ ramen, ravioli, red_velvet_cake, risotto, samosa, sashimi, scallops, seaweed_salad,
59
+ shrimp_and_grits, spaghetti_bolognese, spaghetti_carbonara, spring_rolls, steak,
60
+ strawberry_shortcake, sushi, tacos, takoyaki, tiramisu, tuna_tartare, waffles
61
+ ```
62
+
63
+ ## 🛠️ Technical Details
64
+
65
+ ### Architecture
66
+ - **Base Model**: EfficientNet-B0 (pre-trained on ImageNet)
67
+ - **Fine-tuning**: Transfer learning with Food-101 dataset
68
+ - **Optimization**: ONNX Runtime for fast inference
69
+ - **Input Size**: 224×224×3 RGB images
70
+
71
+ ### Training Pipeline
72
+ 1. **Data Augmentation**: Albumentations library for robust training
73
+ 2. **Transfer Learning**: Fine-tuned pre-trained EfficientNet-B0
74
+ 3. **Advanced Training**: Early stopping, gradient clipping, AMP
75
+ 4. **Validation**: 10% held-out validation set for model selection (see the split sketch after this list)
76
+
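+ The validation split is deterministic (seeded); a minimal sketch of reproducing it with the helpers in `scripts/train.py`, assuming this repository checkout and the bundled `food-101/food-101` data:
+
+ ```python
+ # Sketch only: reproduce the train/val/test split used during training (seed=42, val_split=0.1)
+ import sys
+ from pathlib import Path
+
+ sys.path.append("scripts")  # load_food101_splits lives in scripts/train.py
+ from train import load_food101_splits
+
+ train_s, val_s, test_s, idx_to_class = load_food101_splits(Path("food-101/food-101"), val_split=0.1, seed=42)
+ print(len(train_s), len(val_s), len(test_s), len(idx_to_class))  # with the standard Food-101 metadata: 68175 / 7575 / 25250 samples, 101 classes
+ ```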
77
+ ### Deployment Stack
78
+ - **Model Format**: ONNX (optimized for inference)
79
+ - **Backend**: Python with ONNX Runtime (see the inference sketch after this list)
80
+ - **Frontend**: Gradio web interface
81
+ - **Hosting**: Hugging Face Spaces
82
+
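+ As a rough illustration of what the backend does per request (not the exact app code; the preprocessing approximates `scripts/predict.py`, and the image path is illustrative):
+
+ ```python
+ # Minimal ONNX Runtime inference sketch with ImageNet normalization
+ import numpy as np
+ import onnxruntime as ort
+ from PIL import Image
+
+ session = ort.InferenceSession("models/efficientnet_b0_food101.onnx", providers=["CPUExecutionProvider"])
+ img = Image.open("my_food_photo.jpg").convert("RGB").resize((224, 224))
+ x = np.asarray(img, dtype=np.float32) / 255.0
+ x = (x - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])
+ x = x.transpose(2, 0, 1)[None].astype(np.float32)            # NCHW, batch of 1
+ logits = session.run(None, {session.get_inputs()[0].name: x})[0]
+ print(int(logits.argmax()))                                   # index into meta/classes.txt
+ ```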
83
+ ## 📝 How to Use
84
+
85
+ 1. **Upload an Image**: Click on the upload area or drag & drop a food image
86
+ 2. **Set Predictions**: Choose how many top predictions you want (1-10)
87
+ 3. **Get Results**: Click "Submit" to see predictions with confidence scores
88
+ 4. **Try Examples**: Use the provided example images to test the model
89
+
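+ Outside the web UI, the bundled ONNX predictor can also be driven from Python. A hedged sketch, assuming this repository checkout (the sample image is one of the bundled examples):
+
+ ```python
+ # Sketch only: programmatic use of scripts/predict.py's Food101Predictor
+ import sys
+ from pathlib import Path
+
+ sys.path.append("scripts")  # predict.py imports `train` from the same folder
+ from predict import Food101Predictor
+ from train import load_food101_splits
+
+ _, _, _, idx_to_class = load_food101_splits(Path("food-101/food-101"), val_split=0.1, seed=42)
+ class_names = [idx_to_class[i] for i in range(len(idx_to_class))]
+ predictor = Food101Predictor(Path("models/efficientnet_b0_food101.onnx"), class_names)
+ classes, probs, ms = predictor.predict(Path("food-101/food-101/images/pizza/1001116.jpg"), top_k=5)
+ print(classes[0], f"{probs[0]:.1%}", f"{ms:.1f} ms")
+ ```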
90
+ ## 🧪 Model Evaluation
91
+
92
+ ### Performance Metrics
93
+ - **Accuracy**: 84.49%
94
+ - **Macro F1-Score**: 84.40%
95
+ - **Weighted F1-Score**: 84.40%
96
+ - **Top-5 Accuracy**: 96.72%
97
+
98
+ ### Most Challenging Classes
99
+ The model struggles most with:
100
+ 1. Steak (51.35% F1-score)
101
+ 2. Apple Pie (63.36% F1-score)
102
+ 3. Pork Chop (66.02% F1-score)
103
+
104
+ ## 🔬 Explainability
105
+
106
+ The repository ships a Grad-CAM utility (`scripts/gradcam.py`) that visualizes which regions of the image the model focuses on for a given prediction, adding transparency to the decision-making process.
107
+
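+ A hedged sketch of generating such an overlay with the helpers in `scripts/gradcam.py`; note that this needs a PyTorch checkpoint, which is not part of this Space, so the checkpoint path below is hypothetical:
+
+ ```python
+ # Sketch only: Grad-CAM overlay using the functions defined in scripts/gradcam.py
+ import sys
+ import numpy as np
+ import torch
+ from PIL import Image
+
+ sys.path.append("scripts")
+ from gradcam import build_preprocess, generate_gradcam, overlay_heatmap
+ from train import build_model
+
+ model = build_model("efficientnet_b0", num_classes=101, pretrained=False, freeze_backbone=False)
+ model.load_state_dict(torch.load("checkpoints/efficientnet_b0.pt", map_location="cpu"))  # hypothetical checkpoint path
+ model.eval()
+
+ original = np.array(Image.open("food-101/food-101/images/pizza/1001116.jpg").convert("RGB"))
+ tensor = build_preprocess(224)(image=original)["image"].unsqueeze(0)
+ heatmap = generate_gradcam(model, tensor, target_class=None)
+ Image.fromarray(overlay_heatmap(original, heatmap, alpha=0.4)).save("gradcam_pizza.jpg")
+ ```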
108
+ ## 📜 License
109
+
110
+ This project is licensed under the MIT License - see the LICENSE file for details.
111
+
112
+ ## 🙏 Acknowledgments
113
+
114
+ - **Dataset**: Food-101 dataset by Bossard et al.
115
+ - **Framework**: PyTorch and torchvision
116
+ - **Optimization**: ONNX Runtime
117
+ - **Interface**: Gradio
118
+ - **Hosting**: Hugging Face Spaces
119
+
120
+ ---
121
+
122
+ Built with ❤️ using PyTorch, ONNX, and Gradio
app.py ADDED
@@ -0,0 +1,41 @@
1
+ """Hugging Face Spaces deployment app for Food-101 classification."""
2
+
3
+ # This is the main app file expected by Hugging Face Spaces
4
+ # It imports and runs the Gradio app from gradio_app/app.py
5
+
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ # Add current directory to path
10
+ sys.path.append(str(Path(__file__).parent))
11
+
12
+ # Import the Gradio app
13
+ from gradio_app.app import GradioFood101App
14
+
15
+ def main():
16
+ """Main function for Hugging Face Spaces deployment."""
17
+ try:
18
+ # Initialize the app
19
+ print("[HF SPACES] Initializing Food-101 Classifier App...")
20
+ app = GradioFood101App()
21
+
22
+ # Create interface
23
+ print("[HF SPACES] Creating Gradio interface...")
24
+ interface = app.create_interface()
25
+
26
+ # Launch the app for HF Spaces
27
+ print("[HF SPACES] Launching app for Hugging Face Spaces...")
28
+ interface.launch(
29
+ share=False,
30
+ server_name="0.0.0.0",
31
+ server_port=7860,
32
+ show_error=True,
33
+ # note: enable_queue is no longer a launch() argument in Gradio 4.x; queueing is enabled by default
34
+ )
35
+
36
+ except Exception as e:
37
+ print(f"[ERROR] Failed to launch HF Spaces app: {e}")
38
+ raise
39
+
40
+ if __name__ == "__main__":
41
+ main()
food-101/food-101/images/hamburger/100057.jpg ADDED
food-101/food-101/images/ice_cream/1004744.jpg ADDED
food-101/food-101/images/pizza/1001116.jpg ADDED
food-101/food-101/meta/classes.txt ADDED
@@ -0,0 +1,101 @@
1
+ apple_pie
2
+ baby_back_ribs
3
+ baklava
4
+ beef_carpaccio
5
+ beef_tartare
6
+ beet_salad
7
+ beignets
8
+ bibimbap
9
+ bread_pudding
10
+ breakfast_burrito
11
+ bruschetta
12
+ caesar_salad
13
+ cannoli
14
+ caprese_salad
15
+ carrot_cake
16
+ ceviche
17
+ cheesecake
18
+ cheese_plate
19
+ chicken_curry
20
+ chicken_quesadilla
21
+ chicken_wings
22
+ chocolate_cake
23
+ chocolate_mousse
24
+ churros
25
+ clam_chowder
26
+ club_sandwich
27
+ crab_cakes
28
+ creme_brulee
29
+ croque_madame
30
+ cup_cakes
31
+ deviled_eggs
32
+ donuts
33
+ dumplings
34
+ edamame
35
+ eggs_benedict
36
+ escargots
37
+ falafel
38
+ filet_mignon
39
+ fish_and_chips
40
+ foie_gras
41
+ french_fries
42
+ french_onion_soup
43
+ french_toast
44
+ fried_calamari
45
+ fried_rice
46
+ frozen_yogurt
47
+ garlic_bread
48
+ gnocchi
49
+ greek_salad
50
+ grilled_cheese_sandwich
51
+ grilled_salmon
52
+ guacamole
53
+ gyoza
54
+ hamburger
55
+ hot_and_sour_soup
56
+ hot_dog
57
+ huevos_rancheros
58
+ hummus
59
+ ice_cream
60
+ lasagna
61
+ lobster_bisque
62
+ lobster_roll_sandwich
63
+ macaroni_and_cheese
64
+ macarons
65
+ miso_soup
66
+ mussels
67
+ nachos
68
+ omelette
69
+ onion_rings
70
+ oysters
71
+ pad_thai
72
+ paella
73
+ pancakes
74
+ panna_cotta
75
+ peking_duck
76
+ pho
77
+ pizza
78
+ pork_chop
79
+ poutine
80
+ prime_rib
81
+ pulled_pork_sandwich
82
+ ramen
83
+ ravioli
84
+ red_velvet_cake
85
+ risotto
86
+ samosa
87
+ sashimi
88
+ scallops
89
+ seaweed_salad
90
+ shrimp_and_grits
91
+ spaghetti_bolognese
92
+ spaghetti_carbonara
93
+ spring_rolls
94
+ steak
95
+ strawberry_shortcake
96
+ sushi
97
+ tacos
98
+ takoyaki
99
+ tiramisu
100
+ tuna_tartare
101
+ waffles
food-101/food-101/meta/labels.txt ADDED
@@ -0,0 +1,101 @@
1
+ Apple pie
2
+ Baby back ribs
3
+ Baklava
4
+ Beef carpaccio
5
+ Beef tartare
6
+ Beet salad
7
+ Beignets
8
+ Bibimbap
9
+ Bread pudding
10
+ Breakfast burrito
11
+ Bruschetta
12
+ Caesar salad
13
+ Cannoli
14
+ Caprese salad
15
+ Carrot cake
16
+ Ceviche
17
+ Cheesecake
18
+ Cheese plate
19
+ Chicken curry
20
+ Chicken quesadilla
21
+ Chicken wings
22
+ Chocolate cake
23
+ Chocolate mousse
24
+ Churros
25
+ Clam chowder
26
+ Club sandwich
27
+ Crab cakes
28
+ Creme brulee
29
+ Croque madame
30
+ Cup cakes
31
+ Deviled eggs
32
+ Donuts
33
+ Dumplings
34
+ Edamame
35
+ Eggs benedict
36
+ Escargots
37
+ Falafel
38
+ Filet mignon
39
+ Fish and chips
40
+ Foie gras
41
+ French fries
42
+ French onion soup
43
+ French toast
44
+ Fried calamari
45
+ Fried rice
46
+ Frozen yogurt
47
+ Garlic bread
48
+ Gnocchi
49
+ Greek salad
50
+ Grilled cheese sandwich
51
+ Grilled salmon
52
+ Guacamole
53
+ Gyoza
54
+ Hamburger
55
+ Hot and sour soup
56
+ Hot dog
57
+ Huevos rancheros
58
+ Hummus
59
+ Ice cream
60
+ Lasagna
61
+ Lobster bisque
62
+ Lobster roll sandwich
63
+ Macaroni and cheese
64
+ Macarons
65
+ Miso soup
66
+ Mussels
67
+ Nachos
68
+ Omelette
69
+ Onion rings
70
+ Oysters
71
+ Pad thai
72
+ Paella
73
+ Pancakes
74
+ Panna cotta
75
+ Peking duck
76
+ Pho
77
+ Pizza
78
+ Pork chop
79
+ Poutine
80
+ Prime rib
81
+ Pulled pork sandwich
82
+ Ramen
83
+ Ravioli
84
+ Red velvet cake
85
+ Risotto
86
+ Samosa
87
+ Sashimi
88
+ Scallops
89
+ Seaweed salad
90
+ Shrimp and grits
91
+ Spaghetti bolognese
92
+ Spaghetti carbonara
93
+ Spring rolls
94
+ Steak
95
+ Strawberry shortcake
96
+ Sushi
97
+ Tacos
98
+ Takoyaki
99
+ Tiramisu
100
+ Tuna tartare
101
+ Waffles
food-101/food-101/meta/test.json ADDED
The diff for this file is too large to render. See raw diff
 
food-101/food-101/meta/test.txt ADDED
The diff for this file is too large to render. See raw diff
 
food-101/food-101/meta/train.json ADDED
The diff for this file is too large to render. See raw diff
 
food-101/food-101/meta/train.txt ADDED
The diff for this file is too large to render. See raw diff
 
gradio_app/app.py ADDED
@@ -0,0 +1,225 @@
1
+ """Gradio demo app for Food-101 classification."""
2
+
3
+ import sys
4
+ from pathlib import Path
5
+ from typing import Tuple, Dict, List
6
+ import time
7
+ import tempfile
8
+
9
+ import gradio as gr
10
+ import numpy as np
11
+ from PIL import Image
12
+
13
+ # Add scripts directory to path
14
+ project_root = Path(__file__).parent.parent
15
+ sys.path.append(str(project_root / "scripts"))
16
+
17
+ from predict import Food101Predictor
18
+ from train import load_food101_splits
19
+
20
+
21
+ class GradioFood101App:
22
+ """Gradio application for Food-101 classification."""
23
+
24
+ def __init__(self):
25
+ """Initialize the Gradio app with the ONNX predictor."""
26
+ self.predictor = None
27
+ self.load_model()
28
+
29
+ def load_model(self):
30
+ """Load the ONNX predictor."""
31
+ try:
32
+ # Paths
33
+ model_path = project_root / "models/efficientnet_b0_food101.onnx"
34
+ data_dir = project_root / "food-101/food-101"
35
+
36
+ # Load class names
37
+ _, _, _, idx_to_class = load_food101_splits(data_dir, val_split=0.1, seed=42)
38
+ class_names = [idx_to_class[i] for i in range(len(idx_to_class))]
39
+
40
+ # Initialize predictor
41
+ self.predictor = Food101Predictor(model_path, class_names)
42
+ print(f"[GRADIO] Model loaded successfully with {len(class_names)} classes")
43
+
44
+ except Exception as e:
45
+ print(f"[ERROR] Failed to load model: {e}")
46
+ raise
47
+
48
+ def predict_image(self, image: Image.Image, top_k: int = 5) -> Tuple[Dict, str]:
49
+ """
50
+ Predict food class for uploaded image.
51
+
52
+ Args:
53
+ image: PIL Image
54
+ top_k: Number of top predictions
55
+
56
+ Returns:
57
+ (confidences_dict, info_text)
58
+ """
59
+ if image is None:
60
+ return {}, "Please upload an image first!"
61
+
62
+ if self.predictor is None:
63
+ return {}, "Model not loaded. Please try again."
64
+
65
+ try:
66
+ # Save image temporarily
67
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file:
68
+ image.save(tmp_file.name)
69
+ temp_path = Path(tmp_file.name)
70
+
71
+ # Run prediction
72
+ start_time = time.time()
73
+ predictions, probabilities, inference_time = self.predictor.predict(temp_path, top_k)
74
+ total_time = (time.time() - start_time) * 1000
75
+
76
+ # Clean up
77
+ temp_path.unlink(missing_ok=True)
78
+
79
+ # Format results for Gradio
80
+ confidences = {}
81
+ for pred, prob in zip(predictions, probabilities):
82
+ confidences[pred.replace('_', ' ').title()] = float(prob)
83
+
84
+ # Create info text
85
+ info_lines = [
86
+ f"🔍 **Prediction Results**",
87
+ f"⚡ **Inference Time**: {inference_time:.2f}ms",
88
+ f"🕒 **Total Time**: {total_time:.2f}ms",
89
+ f"🧠 **Model**: EfficientNet-B0 (ONNX)",
90
+ f"📊 **Top Prediction**: {predictions[0].replace('_', ' ').title()} ({probabilities[0]*100:.1f}%)"
91
+ ]
92
+
93
+ info_text = "\n".join(info_lines)
94
+
95
+ return confidences, info_text
96
+
97
+ except Exception as e:
98
+ temp_path.unlink(missing_ok=True)
99
+ return {}, f"❌ **Error**: {str(e)}"
100
+
101
+ def get_examples(self) -> List[List]:
102
+ """Get example images for the demo."""
103
+ examples_dir = project_root / "food-101/food-101/images"
104
+ examples = []
105
+
106
+ # Select a few example images from different classes
107
+ example_classes = ['pizza', 'hamburger', 'ice_cream']
108
+
109
+ for class_name in example_classes:
110
+ class_dir = examples_dir / class_name
111
+ if class_dir.exists():
112
+ # Get first image from class
113
+ images = list(class_dir.glob("*.jpg"))
114
+ if images:
115
+ # Format: [image_path, top_k_value]
116
+ examples.append([str(images[0]), 5])
117
+
118
+ # If no examples found, return empty list (Gradio will handle gracefully)
119
+ return examples if examples else []
120
+
121
+ def create_interface(self) -> gr.Interface:
122
+ """Create and return the Gradio interface."""
123
+
124
+ # Custom CSS for better styling
125
+ css = """
126
+ .main-header {
127
+ text-align: center;
128
+ background: linear-gradient(90deg, #ff6b6b, #4ecdc4);
129
+ -webkit-background-clip: text;
130
+ -webkit-text-fill-color: transparent;
131
+ font-size: 2.5em;
132
+ font-weight: bold;
133
+ margin-bottom: 20px;
134
+ }
135
+ .info-box {
136
+ background-color: #f0f8ff;
137
+ border-left: 5px solid #4ecdc4;
138
+ padding: 15px;
139
+ margin: 10px 0;
140
+ border-radius: 5px;
141
+ }
142
+ """
143
+
144
+ # Interface description
145
+ description = """
146
+ ## 🍕 Food-101 Image Classifier
147
+
148
+ Upload an image of food and get AI-powered predictions! This demo uses a fine-tuned **EfficientNet-B0** model
149
+ trained on the Food-101 dataset to classify 101 different types of food.
150
+
151
+ ### 🎯 **Model Performance**
152
+ - **Accuracy**: 84.49% on test set
153
+ - **Inference Speed**: ~7ms per image
154
+ - **Classes**: 101 different food types
155
+
156
+ ### 🚀 **How to use**
157
+ 1. Upload an image or try one of our examples
158
+ 2. Adjust the number of top predictions (1-10)
159
+ 3. Click Submit to get predictions with confidence scores!
160
+ """
161
+
162
+ # Create the interface
163
+ interface = gr.Interface(
164
+ fn=self.predict_image,
165
+ inputs=[
166
+ gr.Image(
167
+ type="pil",
168
+ label="📸 Upload Food Image",
169
+ height=300
170
+ ),
171
+ gr.Slider(
172
+ minimum=1,
173
+ maximum=10,
174
+ value=5,
175
+ step=1,
176
+ label="🔢 Number of Predictions"
177
+ )
178
+ ],
179
+ outputs=[
180
+ gr.Label(
181
+ label="🏆 Predictions & Confidence Scores",
182
+ num_top_classes=10
183
+ ),
184
+ gr.Markdown(
185
+ label="📊 Prediction Details"
186
+ )
187
+ ],
188
+ title="🍔 Food-101 AI Classifier",
189
+ description=description,
190
+ examples=self.get_examples(),
191
+ css=css,
192
+ theme=gr.themes.Soft(),
193
+ flagging_mode="never"
194
+ )
195
+
196
+ return interface
197
+
198
+
199
+ def main():
200
+ """Main function to launch the Gradio app."""
201
+ try:
202
+ # Initialize the app
203
+ print("[GRADIO] Initializing Food-101 Classifier App...")
204
+ app = GradioFood101App()
205
+
206
+ # Create interface
207
+ print("[GRADIO] Creating Gradio interface...")
208
+ interface = app.create_interface()
209
+
210
+ # Launch the app
211
+ print("[GRADIO] Launching app...")
212
+ interface.launch(
213
+ share=False, # Set to True to create public link
214
+ server_name="0.0.0.0",
215
+ server_port=7860,
216
+ show_error=True
217
+ )
218
+
219
+ except Exception as e:
220
+ print(f"[ERROR] Failed to launch Gradio app: {e}")
221
+ raise
222
+
223
+
224
+ if __name__ == "__main__":
225
+ main()
models/efficientnet_b0_food101.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a19fedaec8bb25ca7af8d49a71aaf2d5f71588bd96d859c252fd1e4902345179
3
+ size 16537732
requirements.txt ADDED
@@ -0,0 +1,23 @@
1
+ # Requirements for Hugging Face Spaces deployment
2
+ # Core dependencies for Food-101 classifier
3
+
4
+ # Deep Learning & Computer Vision
5
+ torch>=2.0.0
6
+ torchvision>=0.15.0
7
+ onnxruntime>=1.20.0
8
+ onnx>=1.15.0
9
+
10
+ # Image Processing
11
+ pillow>=9.0.0
12
+ opencv-python>=4.5.0
13
+ albumentations>=1.3.0
14
+
15
+ # ML & Data Science
16
+ numpy>=1.24.0
17
+ scikit-learn>=1.3.0
18
+
19
+ # Web Interface
20
+ gradio>=4.0.0
21
+
22
+ # Utilities
23
+ # pathlib ships with the Python 3 standard library; the PyPI "pathlib" backport is unnecessary
scripts/__pycache__/check_gpu.cpython-312.pyc ADDED
Binary file (1.37 kB).
 
scripts/__pycache__/evaluate.cpython-312.pyc ADDED
Binary file (9.39 kB).
 
scripts/__pycache__/gradcam.cpython-312.pyc ADDED
Binary file (9.18 kB).
 
scripts/__pycache__/predict.cpython-312.pyc ADDED
Binary file (9.5 kB).
 
scripts/__pycache__/train.cpython-312.pyc ADDED
Binary file (28.6 kB).
 
scripts/check_gpu.py ADDED
@@ -0,0 +1,21 @@
1
+ """Utility script to report CUDA GPU availability for PyTorch."""
2
+ from __future__ import annotations
3
+
4
+ import torch
5
+
6
+ def main() -> None:
7
+ has_cuda = torch.cuda.is_available()
8
+ print(f"torch.cuda.is_available(): {has_cuda}")
9
+ if has_cuda:
10
+ num_devices = torch.cuda.device_count()
11
+ print(f"Detected CUDA devices: {num_devices}")
12
+ for idx in range(num_devices):
13
+ name = torch.cuda.get_device_name(idx)
14
+ capability = torch.cuda.get_device_capability(idx)
15
+ print(f" - Device {idx}: {name} (compute capability {capability[0]}.{capability[1]})")
16
+ else:
17
+ print("No CUDA-capable GPU detected. Training will fall back to CPU.")
18
+
19
+
20
+ if __name__ == "__main__":
21
+ main()
scripts/evaluate.py ADDED
@@ -0,0 +1,136 @@
1
+ """Evaluation utilities for Food-101 classifiers (Phase 5)."""
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import json
6
+ from pathlib import Path
7
+ from typing import Dict, List, Sequence
8
+
9
+ import albumentations as A
10
+ import numpy as np
11
+ import torch
12
+ from albumentations.pytorch import ToTensorV2
13
+ from PIL import Image
14
+ from sklearn.metrics import classification_report, confusion_matrix, top_k_accuracy_score
15
+ from torch.utils.data import DataLoader, Dataset
16
+
17
+ from train import BaselineCNN, Sample, build_model, load_food101_splits, set_seed
18
+
19
+
20
+ class Food101EvalDataset(Dataset):
21
+ """Thin dataset wrapper that applies evaluation transforms."""
22
+
23
+ def __init__(self, samples: Sequence[Sample], transform: A.BasicTransform) -> None:
24
+ self.samples = list(samples)
25
+ self.transform = transform
26
+
27
+ def __len__(self) -> int:
28
+ return len(self.samples)
29
+
30
+ def __getitem__(self, index: int) -> Dict[str, torch.Tensor | int | str]:
31
+ sample = self.samples[index]
32
+ with sample.path.open("rb") as file:
33
+ array = np.array(Image.open(file).convert("RGB"))
34
+ tensor = self.transform(image=array)["image"]
35
+ return {"image": tensor, "label": sample.label, "path": str(sample.path)}
36
+
37
+
38
+ def parse_args() -> argparse.Namespace:
39
+ parser = argparse.ArgumentParser(description="Evaluate Food-101 checkpoints")
40
+ parser.add_argument("--data-dir", type=Path, default=Path("data/raw/food-101"), help="Dataset root")
41
+ parser.add_argument("--checkpoint", type=Path, required=True, help="Checkpoint file to evaluate")
42
+ parser.add_argument("--model", choices=["baseline", "resnet50", "efficientnet_b0"], required=True)
43
+ parser.add_argument("--split", choices=["val", "test"], default="test", help="Dataset split to use")
44
+ parser.add_argument("--batch-size", type=int, default=64)
45
+ parser.add_argument("--num-workers", type=int, default=4)
46
+ parser.add_argument("--image-size", type=int, default=224)
47
+ parser.add_argument("--seed", type=int, default=42)
48
+ parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
49
+ parser.add_argument("--topk", nargs="*", type=int, default=[1, 5], help="Top-k accuracies to report")
50
+ parser.add_argument("--report-json", type=Path, default=None, help="Optional path to dump JSON metrics")
51
+ return parser.parse_args()
52
+
53
+
54
+ def build_eval_transform(image_size: int) -> A.BasicTransform:
55
+ return A.Compose(
56
+ [
57
+ A.Resize(height=image_size + 32, width=image_size + 32),
58
+ A.CenterCrop(height=image_size, width=image_size),
59
+ A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
60
+ ToTensorV2(),
61
+ ]
62
+ )
63
+
64
+
65
+ def run_evaluation(args: argparse.Namespace) -> Dict[str, float]:
66
+ set_seed(args.seed)
67
+ device = torch.device(args.device)
68
+
69
+ data_dir = args.data_dir.expanduser().resolve()
70
+ train_samples, val_samples, test_samples, idx_to_class = load_food101_splits(data_dir, val_split=0.1, seed=args.seed)
71
+ class_names: List[str] = [idx_to_class[i] for i in range(len(idx_to_class))]
72
+
73
+ split_samples = val_samples if args.split == "val" else test_samples
74
+ transform = build_eval_transform(args.image_size)
75
+ dataset = Food101EvalDataset(split_samples, transform=transform)
76
+ dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=torch.cuda.is_available())
77
+
78
+ model = build_model(args.model, num_classes=len(class_names), pretrained=False, freeze_backbone=False)
79
+ try:
80
+ state_dict = torch.load(args.checkpoint, map_location=device, weights_only=True)
81
+ except TypeError:
82
+ state_dict = torch.load(args.checkpoint, map_location=device)
83
+ model.load_state_dict(state_dict)
84
+ model.to(device)
85
+ model.eval()
86
+
87
+ all_probs: List[torch.Tensor] = []
88
+ all_labels: List[int] = []
89
+
90
+ with torch.no_grad():
91
+ for batch in dataloader:
92
+ inputs = batch["image"].to(device)
93
+ outputs = model(inputs)
94
+ probs = torch.softmax(outputs, dim=1)
95
+ all_probs.append(probs.cpu())
96
+ all_labels.extend(batch["label"].tolist())
97
+
98
+ probs_tensor = torch.cat(all_probs, dim=0)
99
+ preds = probs_tensor.argmax(dim=1).numpy()
100
+ labels_np = np.array(all_labels)
101
+
102
+ report = classification_report(labels_np, preds, target_names=class_names, output_dict=True, zero_division=0)
103
+ conf_mat = confusion_matrix(labels_np, preds)
104
+
105
+ metrics: Dict[str, float] = {
106
+ "accuracy": report["accuracy"],
107
+ "macro_precision": report["macro avg"]["precision"],
108
+ "macro_recall": report["macro avg"]["recall"],
109
+ "macro_f1": report["macro avg"]["f1-score"],
110
+ "weighted_f1": report["weighted avg"]["f1-score"],
111
+ }
112
+
113
+ for k in args.topk:
114
+ metrics[f"top{k}_accuracy"] = top_k_accuracy_score(labels_np, probs_tensor, k=k, labels=list(range(len(class_names))))
115
+
116
+ if args.report_json:
117
+ args.report_json.parent.mkdir(parents=True, exist_ok=True)
118
+ with args.report_json.open("w") as f:
119
+ json.dump({"metrics": metrics, "classification_report": report, "confusion_matrix": conf_mat.tolist()}, f, indent=2)
120
+
121
+ print("=== Metrics ===")
122
+ for key, value in metrics.items():
123
+ print(f"{key}: {value:.4f}")
124
+
125
+ print("=== Confusion Matrix (sample) ===")
126
+ print(conf_mat)
127
+ return metrics
128
+
129
+
130
+ def main() -> None:
131
+ args = parse_args()
132
+ run_evaluation(args)
133
+
134
+
135
+ if __name__ == "__main__":
136
+ main()
scripts/export_model.py ADDED
@@ -0,0 +1,159 @@
1
+ """Export PyTorch models to ONNX format for optimized inference."""
2
+
3
+ import argparse
4
+ from pathlib import Path
5
+ import torch
6
+ import onnx
7
+ import onnxruntime as ort
8
+ import numpy as np
9
+
10
+ from train import build_model, load_food101_splits, set_seed
11
+
12
+
13
+ def parse_args() -> argparse.Namespace:
14
+ parser = argparse.ArgumentParser(description="Export PyTorch model to ONNX")
15
+ parser.add_argument("--checkpoint", type=Path, required=True, help="Path to PyTorch checkpoint")
16
+ parser.add_argument("--model", choices=["baseline", "resnet50", "efficientnet_b0"], required=True)
17
+ parser.add_argument("--output", type=Path, required=True, help="Output ONNX file path")
18
+ parser.add_argument("--data-dir", type=Path, default=Path("food-101/food-101"), help="Dataset root")
19
+ parser.add_argument("--input-size", type=int, default=224, help="Input image size")
20
+ parser.add_argument("--batch-size", type=int, default=1, help="Batch size for export")
21
+ parser.add_argument("--device", type=str, default="cpu", help="Device for export")
22
+ parser.add_argument("--opset-version", type=int, default=11, help="ONNX opset version")
23
+ parser.add_argument("--seed", type=int, default=42)
24
+ return parser.parse_args()
25
+
26
+
27
+ def export_to_onnx(
28
+ model: torch.nn.Module,
29
+ output_path: Path,
30
+ input_size: int,
31
+ batch_size: int,
32
+ device: torch.device,
33
+ opset_version: int = 11
34
+ ) -> None:
35
+ """Export PyTorch model to ONNX format."""
36
+
37
+ model.eval()
38
+
39
+ # Create dummy input tensor
40
+ dummy_input = torch.randn(batch_size, 3, input_size, input_size, device=device)
41
+
42
+ # Export to ONNX
43
+ output_path.parent.mkdir(parents=True, exist_ok=True)
44
+
45
+ torch.onnx.export(
46
+ model,
47
+ dummy_input,
48
+ str(output_path),
49
+ export_params=True,
50
+ opset_version=opset_version,
51
+ do_constant_folding=True,
52
+ input_names=['input'],
53
+ output_names=['output'],
54
+ dynamic_axes={
55
+ 'input': {0: 'batch_size'},
56
+ 'output': {0: 'batch_size'}
57
+ }
58
+ )
59
+
60
+ print(f"[SUCCESS] Model exported to {output_path}")
61
+
62
+
63
+ def verify_onnx_model(onnx_path: Path, pytorch_model: torch.nn.Module, input_size: int, device: torch.device) -> None:
64
+ """Verify that ONNX model produces same outputs as PyTorch model."""
65
+
66
+ # Load ONNX model
67
+ onnx_model = onnx.load(str(onnx_path))
68
+ onnx.checker.check_model(onnx_model)
69
+ print("[SUCCESS] ONNX model is valid")
70
+
71
+ # Create ONNX Runtime session
72
+ ort_session = ort.InferenceSession(str(onnx_path))
73
+
74
+ # Create test input
75
+ test_input = torch.randn(1, 3, input_size, input_size, device=device)
76
+
77
+ # PyTorch inference
78
+ pytorch_model.eval()
79
+ with torch.no_grad():
80
+ pytorch_output = pytorch_model(test_input).cpu().numpy()
81
+
82
+ # ONNX inference
83
+ onnx_input = test_input.cpu().numpy()
84
+ onnx_output = ort_session.run(['output'], {'input': onnx_input})[0]
85
+
86
+ # Compare outputs
87
+ max_diff = np.max(np.abs(pytorch_output - onnx_output))
88
+ print(f"[SUCCESS] Max difference between PyTorch and ONNX outputs: {max_diff:.6f}")
89
+
90
+ if max_diff < 1e-5:
91
+ print("[SUCCESS] ONNX model verification successful!")
92
+ else:
93
+ print("[WARNING] Large difference detected - verify model compatibility")
94
+
95
+
96
+ def get_model_info(model: torch.nn.Module, input_size: int) -> None:
97
+ """Print model information."""
98
+
99
+ # Count parameters
100
+ total_params = sum(p.numel() for p in model.parameters())
101
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
102
+
103
+ print(f"Model Information:")
104
+ print(f" Total parameters: {total_params:,}")
105
+ print(f" Trainable parameters: {trainable_params:,}")
106
+ print(f" Input size: {input_size}x{input_size}x3")
107
+
108
+
109
+ def main() -> None:
110
+ args = parse_args()
111
+ set_seed(args.seed)
112
+
113
+ device = torch.device(args.device)
114
+
115
+ # Load dataset info to get number of classes
116
+ data_dir = args.data_dir.expanduser().resolve()
117
+ _, _, _, idx_to_class = load_food101_splits(data_dir, val_split=0.1, seed=args.seed)
118
+ num_classes = len(idx_to_class)
119
+
120
+ # Load model
121
+ print(f"[INFO] Loading {args.model} model...")
122
+ model = build_model(args.model, num_classes=num_classes, pretrained=False, freeze_backbone=False)
123
+
124
+ # Load checkpoint
125
+ try:
126
+ state_dict = torch.load(args.checkpoint, map_location=device, weights_only=True)
127
+ except TypeError:
128
+ state_dict = torch.load(args.checkpoint, map_location=device)
129
+
130
+ model.load_state_dict(state_dict)
131
+ model.to(device)
132
+
133
+ # Print model info
134
+ get_model_info(model, args.input_size)
135
+
136
+ # Export to ONNX
137
+ print(f"[INFO] Exporting to ONNX...")
138
+ export_to_onnx(
139
+ model=model,
140
+ output_path=args.output,
141
+ input_size=args.input_size,
142
+ batch_size=args.batch_size,
143
+ device=device,
144
+ opset_version=args.opset_version
145
+ )
146
+
147
+ # Verify ONNX model
148
+ print(f"[INFO] Verifying ONNX model...")
149
+ verify_onnx_model(args.output, model, args.input_size, device)
150
+
151
+ # Print file size
152
+ file_size_mb = args.output.stat().st_size / (1024 * 1024)
153
+ print(f"[INFO] ONNX file size: {file_size_mb:.2f} MB")
154
+
155
+ print("[SUCCESS] Export completed successfully!")
156
+
157
+
158
+ if __name__ == "__main__":
159
+ main()
scripts/gradcam.py ADDED
@@ -0,0 +1,143 @@
1
+ """Grad-CAM utilities for Food-101 models."""
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ from pathlib import Path
6
+ from typing import Dict, List
7
+
8
+ import albumentations as A
9
+ import cv2
10
+ import numpy as np
11
+ import torch
12
+ from albumentations.pytorch import ToTensorV2
13
+ from PIL import Image
14
+
15
+ from train import build_model, load_food101_splits, set_seed
16
+
17
+
18
+ def parse_args() -> argparse.Namespace:
19
+ parser = argparse.ArgumentParser(description="Generate Grad-CAM visualizations")
20
+ parser.add_argument("--data-dir", type=Path, default=Path("data/raw/food-101"))
21
+ parser.add_argument("--checkpoint", type=Path, required=True)
22
+ parser.add_argument("--model", choices=["baseline", "resnet50", "efficientnet_b0"], required=True)
23
+ parser.add_argument("--class-index", type=int, default=None, help="Target class index for Grad-CAM (defaults to prediction)")
24
+ parser.add_argument("--image", type=Path, required=True, help="Path to input image file")
25
+ parser.add_argument("--output", type=Path, required=True, help="Output heatmap path")
26
+ parser.add_argument("--image-size", type=int, default=224)
27
+ parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
28
+ parser.add_argument("--seed", type=int, default=42)
29
+ parser.add_argument("--alpha", type=float, default=0.4, help="Blending factor for heatmap overlay")
30
+ return parser.parse_args()
31
+
32
+
33
+ def build_preprocess(image_size: int) -> A.BasicTransform:
34
+ return A.Compose(
35
+ [
36
+ A.Resize(height=image_size, width=image_size),
37
+ A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
38
+ ToTensorV2(),
39
+ ]
40
+ )
41
+
42
+
43
+ def get_last_conv_layer(model: torch.nn.Module) -> torch.nn.Module:
44
+ # EfficientNet B0
45
+ if hasattr(model, "features") and hasattr(model.features, "_modules"):
46
+ return model.features[-1][-1] # Last block, last layer
47
+
48
+ # ResNet50
49
+ if hasattr(model, "layer4"):
50
+ return model.layer4[-1].conv3 # type: ignore[attr-defined]
51
+
52
+ # Baseline CNN
53
+ if hasattr(model, "features"):
54
+ for module in reversed(model.features):
55
+ if isinstance(module, torch.nn.Conv2d):
56
+ return module
57
+
58
+ # Generic fallback
59
+ if hasattr(model, "classifier") and isinstance(model.classifier, torch.nn.Sequential):
60
+ for module in reversed(model.classifier):
61
+ if isinstance(module, torch.nn.Conv2d):
62
+ return module
63
+
64
+ raise RuntimeError("Could not automatically determine last convolutional layer")
65
+
66
+
67
+ def generate_gradcam(
68
+ model: torch.nn.Module,
69
+ image_tensor: torch.Tensor,
70
+ target_class: int | None,
71
+ ) -> np.ndarray:
72
+ gradients: List[torch.Tensor] = []
73
+ activations: List[torch.Tensor] = []
74
+
75
+ def backward_hook(module: torch.nn.Module, grad_input: tuple[torch.Tensor, ...], grad_output: tuple[torch.Tensor, ...]) -> None:
76
+ gradients.append(grad_output[0])
77
+
78
+ def forward_hook(module: torch.nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor) -> None:
79
+ activations.append(output)
80
+
81
+ target_layer = get_last_conv_layer(model)
82
+ handle_fwd = target_layer.register_forward_hook(forward_hook)
83
+ handle_bwd = target_layer.register_full_backward_hook(backward_hook)
84
+
85
+ try:
86
+ output = model(image_tensor)
87
+ if target_class is None:
88
+ target_class = int(output.argmax(dim=1).item())
89
+ loss = output[:, target_class].sum()
90
+ model.zero_grad()
91
+ loss.backward()
92
+
93
+ grads = gradients[0]
94
+ acts = activations[0]
95
+ weights = grads.mean(dim=(2, 3), keepdim=True)
96
+ cam = torch.relu((weights * acts).sum(dim=1, keepdim=True))
97
+ cam = torch.nn.functional.interpolate(cam, size=image_tensor.shape[2:], mode="bilinear", align_corners=False)
98
+ cam = cam.squeeze().detach().cpu().numpy()
99
+ cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
100
+ return cam
101
+ finally:
102
+ handle_fwd.remove()
103
+ handle_bwd.remove()
104
+
105
+
106
+ def overlay_heatmap(original: np.ndarray, heatmap: np.ndarray, alpha: float) -> np.ndarray:
107
+ # Resize heatmap to match original image size
108
+ heatmap_resized = cv2.resize(heatmap, (original.shape[1], original.shape[0]))
109
+ heatmap_color = cv2.applyColorMap((heatmap_resized * 255).astype(np.uint8), cv2.COLORMAP_JET)
110
+ heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
111
+ overlay = cv2.addWeighted(heatmap_color, alpha, original, 1 - alpha, 0)
112
+ return overlay
113
+
114
+
115
+ def main() -> None:
116
+ args = parse_args()
117
+ set_seed(args.seed)
118
+
119
+ device = torch.device(args.device)
120
+ model = build_model(args.model, num_classes=101, pretrained=False, freeze_backbone=False)
121
+ try:
122
+ state_dict = torch.load(args.checkpoint, map_location=device, weights_only=True)
123
+ except TypeError:
124
+ state_dict = torch.load(args.checkpoint, map_location=device)
125
+ model.load_state_dict(state_dict)
126
+ model.to(device)
127
+ model.eval()
128
+
129
+ preprocess = build_preprocess(args.image_size)
130
+ with args.image.open("rb") as f:
131
+ original = np.array(Image.open(f).convert("RGB"))
132
+ tensor = preprocess(image=original)["image"].unsqueeze(0).to(device)
133
+
134
+ heatmap = generate_gradcam(model, tensor, args.class_index)
135
+ overlay = overlay_heatmap(original, heatmap, alpha=args.alpha)
136
+
137
+ args.output.parent.mkdir(parents=True, exist_ok=True)
138
+ Image.fromarray(overlay).save(args.output)
139
+ print(f"Grad-CAM saved to {args.output}")
140
+
141
+
142
+ if __name__ == "__main__":
143
+ main()
scripts/predict.py ADDED
@@ -0,0 +1,175 @@
1
+ """Fast inference script using ONNX model for Food-101 classification."""
2
+
3
+ import argparse
4
+ import time
5
+ from pathlib import Path
6
+ from typing import Dict, List, Tuple
7
+
8
+ import albumentations as A
9
+ import numpy as np
10
+ import onnxruntime as ort
11
+ from albumentations.pytorch import ToTensorV2
12
+ from PIL import Image
13
+
14
+ from train import load_food101_splits
15
+
16
+
17
+ class Food101Predictor:
18
+ """Fast ONNX-based predictor for Food-101 classification."""
19
+
20
+ def __init__(self, onnx_path: Path, class_names: List[str], providers: List[str] = None):
21
+ """Initialize predictor with ONNX model."""
22
+
23
+ self.class_names = class_names
24
+ self.num_classes = len(class_names)
25
+
26
+ # Initialize ONNX Runtime session with optimal providers
27
+ if providers is None:
28
+ providers = ['CPUExecutionProvider']
29
+
30
+ self.session = ort.InferenceSession(str(onnx_path), providers=providers)
31
+
32
+ # Get input/output info
33
+ self.input_name = self.session.get_inputs()[0].name
34
+ self.output_name = self.session.get_outputs()[0].name
35
+
36
+ # Create preprocessing transform
37
+ self.transform = A.Compose([
38
+ A.Resize(height=224, width=224),
39
+ A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
40
+ ToTensorV2(),
41
+ ])
42
+
43
+ print(f"[INFO] Predictor initialized with {len(class_names)} classes")
44
+ print(f"[INFO] ONNX Runtime providers: {self.session.get_providers()}")
45
+
46
+ def preprocess_image(self, image_path: Path) -> np.ndarray:
47
+ """Preprocess image for inference."""
48
+
49
+ with image_path.open("rb") as f:
50
+ image = Image.open(f).convert("RGB")
51
+
52
+ # Convert to numpy array
53
+ image_array = np.array(image)
54
+
55
+ # Apply transforms
56
+ transformed = self.transform(image=image_array)
57
+ tensor = transformed["image"]
58
+
59
+ # Add batch dimension and convert to numpy
60
+ batch = tensor.unsqueeze(0).numpy()
61
+ return batch
62
+
63
+ def predict(self, image_path: Path, top_k: int = 5) -> Tuple[List[str], List[float], float]:
64
+ """
65
+ Predict food class for given image.
66
+
67
+ Returns:
68
+ predictions: List of top-k class names
69
+ probabilities: List of top-k probabilities
70
+ inference_time: Time in milliseconds
71
+ """
72
+
73
+ # Preprocess
74
+ start_time = time.time()
75
+ input_batch = self.preprocess_image(image_path)
76
+
77
+ # Run inference
78
+ outputs = self.session.run([self.output_name], {self.input_name: input_batch})[0]
79
+
80
+ # Apply softmax to get probabilities
81
+ exp_outputs = np.exp(outputs - np.max(outputs, axis=1, keepdims=True))
82
+ probabilities = exp_outputs / np.sum(exp_outputs, axis=1, keepdims=True)
83
+
84
+ # Get top-k predictions
85
+ top_indices = np.argsort(probabilities[0])[::-1][:top_k]
86
+ top_probs = probabilities[0][top_indices].tolist()
87
+ top_classes = [self.class_names[i] for i in top_indices]
88
+
89
+ inference_time = (time.time() - start_time) * 1000 # Convert to ms
90
+
91
+ return top_classes, top_probs, inference_time
92
+
93
+ def predict_batch(self, image_paths: List[Path], top_k: int = 5) -> List[Dict]:
94
+ """Predict multiple images at once."""
95
+
96
+ results = []
97
+ for image_path in image_paths:
98
+ classes, probs, time_ms = self.predict(image_path, top_k)
99
+ results.append({
100
+ 'image_path': str(image_path),
101
+ 'predictions': classes,
102
+ 'probabilities': probs,
103
+ 'inference_time_ms': time_ms
104
+ })
105
+
106
+ return results
107
+
108
+
109
+ def parse_args() -> argparse.Namespace:
110
+ parser = argparse.ArgumentParser(description="Fast inference with ONNX model")
111
+ parser.add_argument("--model", type=Path, required=True, help="Path to ONNX model file")
112
+ parser.add_argument("--image", type=Path, required=True, help="Path to input image")
113
+ parser.add_argument("--data-dir", type=Path, default=Path("food-101/food-101"), help="Dataset root for class names")
114
+ parser.add_argument("--top-k", type=int, default=5, help="Number of top predictions to show")
115
+ parser.add_argument("--providers", nargs="*", default=None,
116
+ help="ONNX Runtime providers (e.g., CPUExecutionProvider)")
117
+ parser.add_argument("--seed", type=int, default=42)
118
+ return parser.parse_args()
119
+
120
+
121
+ def benchmark_inference(predictor: Food101Predictor, image_path: Path, num_runs: int = 10) -> None:
122
+ """Benchmark inference speed."""
123
+
124
+ print(f"[INFO] Benchmarking inference with {num_runs} runs...")
125
+
126
+ times = []
127
+ for i in range(num_runs):
128
+ _, _, inference_time = predictor.predict(image_path, top_k=1)
129
+ times.append(inference_time)
130
+ if i == 0:
131
+ print(f"[INFO] First run (cold start): {inference_time:.2f}ms")
132
+
133
+ # Statistics
134
+ mean_time = np.mean(times[1:]) # Exclude cold start
135
+ std_time = np.std(times[1:])
136
+ min_time = np.min(times[1:])
137
+ max_time = np.max(times[1:])
138
+
139
+ print(f"[BENCHMARK] Average inference time: {mean_time:.2f} ± {std_time:.2f}ms")
140
+ print(f"[BENCHMARK] Min: {min_time:.2f}ms, Max: {max_time:.2f}ms")
141
+
142
+ if mean_time < 100:
143
+ print(f"[SUCCESS] Target latency achieved! ({mean_time:.2f}ms < 100ms)")
144
+ else:
145
+ print(f"[WARNING] Target latency not met ({mean_time:.2f}ms >= 100ms)")
146
+
147
+
148
+ def main() -> None:
149
+ args = parse_args()
150
+
151
+ # Load class names from dataset
152
+ data_dir = args.data_dir.expanduser().resolve()
153
+ _, _, _, idx_to_class = load_food101_splits(data_dir, val_split=0.1, seed=args.seed)
154
+ class_names = [idx_to_class[i] for i in range(len(idx_to_class))]
155
+
156
+ # Initialize predictor
157
+ predictor = Food101Predictor(args.model, class_names, providers=args.providers)
158
+
159
+ # Run prediction
160
+ print(f"[INFO] Predicting image: {args.image}")
161
+ predictions, probabilities, inference_time = predictor.predict(args.image, args.top_k)
162
+
163
+ # Display results
164
+ print(f"\n[RESULTS] Inference time: {inference_time:.2f}ms")
165
+ print("Top predictions:")
166
+ for i, (class_name, prob) in enumerate(zip(predictions, probabilities), 1):
167
+ print(f" {i}. {class_name}: {prob:.4f} ({prob*100:.2f}%)")
168
+
169
+ # Run benchmark
170
+ print()
171
+ benchmark_inference(predictor, args.image)
172
+
173
+
174
+ if __name__ == "__main__":
175
+ main()
scripts/train.py ADDED
@@ -0,0 +1,590 @@
1
+ """Training script for Food-101 supporting baseline, transfer, and advanced training utilities."""
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import csv
6
+ import json
7
+ import random
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from typing import Dict, Iterable, List, Optional, Sequence, Tuple
11
+
12
+ import albumentations as A
13
+ import numpy as np
14
+ import torch
15
+ from albumentations.pytorch import ToTensorV2
16
+ from PIL import Image
17
+ from torch import nn
18
+ from torch.amp import GradScaler, autocast
19
+ from torch.optim import Adam
20
+ from torch.utils.data import DataLoader, Dataset
21
+ from torchvision import models
22
+ from torchvision.models import EfficientNet_B0_Weights, ResNet50_Weights
23
+
24
+ try: # Optional dependency used for experiment tracking.
25
+ import wandb # type: ignore
26
+ except ImportError: # pragma: no cover - handled at runtime when library missing.
27
+ wandb = None
28
+
29
+
30
+ @dataclass(frozen=True)
31
+ class Sample:
32
+ """Minimal container storing an image path and class index."""
33
+
34
+ path: Path
35
+ label: int
36
+
37
+
38
+ class Food101Dataset(Dataset):
39
+ """Custom Dataset that loads Food-101 images and applies augmentations."""
40
+
41
+ def __init__(self, samples: Sequence[Sample], transform: A.BasicTransform | None = None) -> None:
42
+ self.samples = list(samples)
43
+ self.transform = transform
44
+
45
+ def __len__(self) -> int:
46
+ return len(self.samples)
47
+
48
+ def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
49
+ sample = self.samples[index]
50
+ with sample.path.open("rb") as file:
51
+ # Convert PIL image to NumPy array so Albumentations can process it.
52
+ array = np.array(Image.open(file).convert("RGB"))
53
+ if self.transform is not None:
54
+ array = self.transform(image=array)["image"]
55
+ return array, sample.label
56
+
57
+
58
+ class BaselineCNN(nn.Module):
59
+ """Lightweight CNN baseline with three feature stages and global pooling."""
60
+
61
+ def __init__(self, num_classes: int) -> None:
62
+ super().__init__()
63
+ self.features = nn.Sequential(
64
+ _conv_block(3, 32),
65
+ _conv_block(32, 64),
66
+ nn.MaxPool2d(2),
67
+ _conv_block(64, 128),
68
+ nn.MaxPool2d(2),
69
+ _conv_block(128, 256),
70
+ nn.MaxPool2d(2),
71
+ )
72
+ self.classifier = nn.Sequential(
73
+ nn.AdaptiveAvgPool2d(1),
74
+ nn.Flatten(),
75
+ nn.Dropout(0.3),
76
+ nn.Linear(256, num_classes),
77
+ )
78
+
79
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
80
+ x = self.features(x)
81
+ return self.classifier(x)
82
+
83
+
84
+ def _conv_block(in_channels: int, out_channels: int) -> nn.Sequential:
85
+ """Creates a Conv-BN-ReLU block reused across the network."""
86
+
87
+ return nn.Sequential(
88
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
89
+ nn.BatchNorm2d(out_channels),
90
+ nn.ReLU(inplace=True),
91
+ )
92
+
93
+
94
+ def parse_args() -> argparse.Namespace:
95
+ """Parses command line arguments controlling training behavior."""
96
+
97
+ parser = argparse.ArgumentParser(description="Train Food-101 image classifiers")
98
+ parser.add_argument("--data-dir", type=Path, default=Path("data/raw/food-101"), help="Root directory of Food-101 dataset")
99
+ parser.add_argument("--model", type=str, choices=["baseline", "resnet50", "efficientnet_b0"], default="baseline", help="Model architecture to train")
100
+ parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
101
+ parser.add_argument("--batch-size", type=int, default=64, help="Mini-batch size")
102
+ parser.add_argument("--learning-rate", type=float, default=1e-3, help="Optimizer learning rate")
103
+ parser.add_argument("--val-split", type=float, default=0.1, help="Fraction of train set used for validation")
104
+ parser.add_argument("--num-workers", type=int, default=4, help="DataLoader worker processes")
105
+ parser.add_argument("--image-size", type=int, default=224, help="Square image size fed to the network")
106
+ parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
107
+ parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Computation device")
108
+ parser.add_argument("--pretrained", action="store_true", default=None, help="Use pretrained weights when available (transfer models)")
109
+ parser.add_argument("--no-pretrained", action="store_false", dest="pretrained", help="Disable pretrained weights for transfer models")
110
+ parser.add_argument("--freeze-backbone", action="store_true", help="Freeze feature extractor when using transfer models")
111
+ parser.add_argument("--experiment-name", type=str, default=None, help="Optional experiment name for checkpoints/logs")
112
+ parser.add_argument("--checkpoint-dir", type=Path, default=Path("checkpoints"), help="Where to store model checkpoints")
113
+ parser.add_argument("--log-dir", type=Path, default=Path("logs"), help="Where to store training logs")
114
+ parser.add_argument("--early-stop-patience", type=int, default=5, help="Epochs to wait before stopping after no validation improvement")
115
+ parser.add_argument("--early-stop-min-delta", type=float, default=0.0, help="Minimum change to qualify as an improvement")
116
+ parser.add_argument("--early-stop-metric", choices=["accuracy", "loss"], default="accuracy", help="Validation metric used for early stopping")
117
+ parser.add_argument("--grad-clip-norm", type=float, default=None, help="Gradient clipping norm (L2). Disabled if not set")
118
+ parser.add_argument("--use-amp", action="store_true", help="Enable mixed precision training (requires CUDA)")
119
+ parser.add_argument("--wandb", dest="use_wandb", action="store_true", help="Log metrics to Weights & Biases")
120
+ parser.add_argument("--no-wandb", dest="use_wandb", action="store_false", help="Disable Weights & Biases logging")
121
+ parser.add_argument("--wandb-project", type=str, default=None, help="Weights & Biases project name")
122
+ parser.add_argument("--wandb-entity", type=str, default=None, help="Weights & Biases entity (team/user)")
123
+ parser.add_argument("--wandb-run-name", type=str, default=None, help="Weights & Biases run name override")
124
+ args = parser.parse_args()
125
+
126
+ # Default to pretrained weights for transfer models unless user explicitly disables them.
127
+ if args.pretrained is None:
128
+ args.pretrained = args.model != "baseline"
129
+ if not hasattr(args, "use_wandb"):
130
+ args.use_wandb = False
131
+ return args
132
+
133
+
134
+ def set_seed(seed: int) -> None:
135
+ """Fixes random seeds for reproducibility across libraries."""
136
+
137
+ random.seed(seed)
138
+ np.random.seed(seed)
139
+ torch.manual_seed(seed)
140
+ if torch.cuda.is_available():
141
+ torch.cuda.manual_seed_all(seed)
142
+ torch.backends.cudnn.deterministic = True
143
+ torch.backends.cudnn.benchmark = False
144
+
145
+
146
+ def build_transforms(image_size: int) -> Tuple[A.BasicTransform, A.BasicTransform]:
147
+ """Constructs augmentation pipelines for train and evaluation splits."""
148
+
149
+ normalize = A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
150
+ train_transform = A.Compose(
151
+ [
152
+ A.RandomResizedCrop(
153
+ size=(image_size, image_size),
154
+ scale=(0.8, 1.0),
155
+ ratio=(0.75, 1.33),
156
+ ),
157
+ A.HorizontalFlip(p=0.5),
158
+ A.Affine(
159
+ scale=(0.9, 1.1),
160
+ translate_percent=(-0.05, 0.05),
161
+ rotate=(-15, 15),
162
+ shear=(0.0, 0.0),
163
+ p=0.3,
164
+ ),
165
+ A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.3),
166
+ normalize,
167
+ ToTensorV2(),
168
+ ]
169
+ )
170
+ eval_transform = A.Compose(
171
+ [
172
+ A.Resize(height=image_size + 32, width=image_size + 32),
173
+ A.CenterCrop(height=image_size, width=image_size),
174
+ normalize,
175
+ ToTensorV2(),
176
+ ]
177
+ )
178
+ return train_transform, eval_transform
179
+
180
+
181
+ def build_model(
182
+ model_name: str,
183
+ num_classes: int,
184
+ pretrained: bool,
185
+ freeze_backbone: bool,
186
+ ) -> nn.Module:
187
+ """Factory that returns the requested architecture configured for Food-101."""
188
+
189
+ if model_name == "baseline":
190
+ model = BaselineCNN(num_classes=num_classes)
191
+ elif model_name == "resnet50":
192
+ weights = ResNet50_Weights.DEFAULT if pretrained else None
193
+ model = models.resnet50(weights=weights)
194
+ if freeze_backbone:
195
+ for param in model.parameters():
196
+ param.requires_grad = False
197
+ in_features = model.fc.in_features
198
+ model.fc = nn.Linear(in_features, num_classes)
199
+ elif model_name == "efficientnet_b0":
200
+ weights = EfficientNet_B0_Weights.DEFAULT if pretrained else None
201
+ model = models.efficientnet_b0(weights=weights)
202
+ if freeze_backbone:
203
+ for param in model.parameters():
204
+ param.requires_grad = False
205
+ in_features = model.classifier[-1].in_features
206
+ model.classifier[-1] = nn.Linear(in_features, num_classes)
207
+ else:
208
+ raise ValueError(f"Unsupported model: {model_name}")
209
+
210
+ return model
211
+
212
+
+ def load_food101_splits(data_dir: Path, val_split: float, seed: int) -> Tuple[List[Sample], List[Sample], List[Sample], Dict[int, str]]:
+     """Loads Food-101 metadata and returns train/val/test splits."""
+
+     images_dir = data_dir / "images"
+     meta_dir = data_dir / "meta"
+     classes = _read_classes(meta_dir / "classes.txt")
+     class_to_idx = {name: idx for idx, name in enumerate(classes)}
+
+     with (meta_dir / "train.json").open() as f:
+         train_meta = json.load(f)
+     with (meta_dir / "test.json").open() as f:
+         test_meta = json.load(f)
+
+     rng = random.Random(seed)
+
+     train_samples: List[Sample] = []
+     val_samples: List[Sample] = []
+     for cls_name, items in train_meta.items():
+         paths = list(items)
+         rng.shuffle(paths)
+         val_count = max(1, int(len(paths) * val_split))
+         val_subset = paths[:val_count]
+         train_subset = paths[val_count:]
+         train_samples.extend(_build_samples(train_subset, images_dir, class_to_idx[cls_name]))
+         val_samples.extend(_build_samples(val_subset, images_dir, class_to_idx[cls_name]))
+
+     test_samples = []
+     for cls_name, items in test_meta.items():
+         test_samples.extend(_build_samples(items, images_dir, class_to_idx[cls_name]))
+
+     idx_to_class = {idx: name for name, idx in class_to_idx.items()}
+     return train_samples, val_samples, test_samples, idx_to_class
+
+
+ def _build_samples(items: Iterable[str], images_dir: Path, label: int) -> List[Sample]:
+     """Creates Sample objects from relative paths and a target label."""
+
+     samples = []
+     for item in items:
+         path = images_dir / f"{item}.jpg"
+         samples.append(Sample(path=path, label=label))
+     return samples
+
+
+ def _read_classes(path: Path) -> List[str]:
+     """Reads the list of Food-101 class names from disk."""
+
+     with path.open() as handle:
+         return [line.strip() for line in handle if line.strip()]
+
+
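These loaders assume the standard Food-101 archive layout, roughly as sketched below; train.json and test.json map each class name to relative image paths without the .jpg suffix, which is why _build_samples appends it:

```
food-101/
├── images/
│   └── <class_name>/<image_id>.jpg
└── meta/
    ├── classes.txt   # one class name per line
    ├── train.json    # {"<class_name>": ["<class_name>/<image_id>", ...], ...}
    └── test.json     # same structure as train.json
```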
+ def create_dataloaders(
+     train_samples: Sequence[Sample],
+     val_samples: Sequence[Sample],
+     test_samples: Sequence[Sample],
+     train_transform: A.BasicTransform,
+     eval_transform: A.BasicTransform,
+     batch_size: int,
+     num_workers: int,
+ ) -> Tuple[DataLoader, DataLoader, DataLoader]:
+     """Wraps datasets with PyTorch DataLoaders."""
+
+     pin_memory = torch.cuda.is_available()
+     train_dataset = Food101Dataset(train_samples, transform=train_transform)
+     val_dataset = Food101Dataset(val_samples, transform=eval_transform)
+     test_dataset = Food101Dataset(test_samples, transform=eval_transform)
+
+     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
+     val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
+     test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
+     return train_loader, val_loader, test_loader
+
+
+ def write_metrics(log_path: Path, rows: Sequence[Dict[str, object]]) -> None:
+     """Appends metric rows to the CSV log."""
+
+     log_path.parent.mkdir(parents=True, exist_ok=True)
+     file_exists = log_path.exists()
+     fieldnames = ["model", "experiment", "epoch", "split", "loss", "accuracy"]
+     with log_path.open("a", newline="") as csvfile:
+         writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+         if not file_exists:
+             writer.writeheader()
+         writer.writerows(rows)
+
+
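With those field names, each experiment log is a flat CSV along these lines (the numbers are placeholders, not measured results):

```
model,experiment,epoch,split,loss,accuracy
efficientnet_b0,efficientnet_b0,1,train,1.2345,0.6789
efficientnet_b0,efficientnet_b0,1,val,1.1000,0.7000
```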
+ def maybe_init_wandb(args: argparse.Namespace, config_extra: Dict[str, object]) -> Optional[object]:
+     """Initializes Weights & Biases run if requested and available."""
+
+     if not args.use_wandb:
+         return None
+     if args.wandb_project is None:
+         print("[wandb] Project not specified; skipping W&B logging.")
+         return None
+     if wandb is None:
+         print("[wandb] Package not installed; skipping W&B logging.")
+         return None
+
+     run_name = args.wandb_run_name or args.experiment_name or args.model
+     config = {k: (str(v) if isinstance(v, Path) else v) for k, v in vars(args).items()}
+     config.update(config_extra)
+     run = wandb.init(project=args.wandb_project, entity=args.wandb_entity, name=run_name, config=config)
+     return run
+
+
+ def train_one_epoch(
+     model: nn.Module,
+     dataloader: DataLoader,
+     criterion: nn.Module,
+     optimizer: torch.optim.Optimizer,
+     device: torch.device,
+     scaler: GradScaler,
+     grad_clip_norm: Optional[float],
+     amp_enabled: bool,
+     clip_params: Sequence[torch.nn.Parameter],
+     amp_device_type: str,
+ ) -> Tuple[float, float]:
+     """Runs one training epoch and returns loss and accuracy."""
+
+     model.train()
+     running_loss = 0.0
+     correct = 0
+     total = 0
+
+     for inputs, targets in dataloader:
+         inputs = inputs.to(device)
+         targets = targets.to(device)
+
+         # Forward + backward pass and optimizer update.
+         optimizer.zero_grad()
+         with autocast(amp_device_type, enabled=amp_enabled):
+             outputs = model(inputs)
+             loss = criterion(outputs, targets)
+
+         if scaler.is_enabled():
+             scaler.scale(loss).backward()
+             if grad_clip_norm is not None:
+                 scaler.unscale_(optimizer)
+                 torch.nn.utils.clip_grad_norm_(clip_params, grad_clip_norm)
+             scaler.step(optimizer)
+             scaler.update()
+         else:
+             loss.backward()
+             if grad_clip_norm is not None:
+                 torch.nn.utils.clip_grad_norm_(clip_params, grad_clip_norm)
+             optimizer.step()
+
+         running_loss += loss.item() * inputs.size(0)
+         preds = outputs.argmax(dim=1)
+         correct += (preds == targets).sum().item()
+         total += targets.size(0)
+
+     epoch_loss = running_loss / max(total, 1)
+     epoch_acc = correct / max(total, 1)
+     return epoch_loss, epoch_acc
+
+
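One detail worth noting in the AMP branch above: gradients are unscaled before clipping, so the clip threshold applies to true gradient magnitudes rather than loss-scaled ones. A minimal, CPU-safe sketch of that ordering, assuming the recent torch.amp GradScaler/autocast API that the calls above already use; the model and data here are made up:

```python
# Standalone sketch of the unscale-then-clip ordering used in train_one_epoch.
import torch
from torch import nn
from torch.amp import GradScaler, autocast

device = "cuda" if torch.cuda.is_available() else "cpu"
amp_enabled = device == "cuda"
model = nn.Linear(8, 3).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(device, enabled=amp_enabled)

x = torch.randn(4, 8, device=device)
y = torch.randint(0, 3, (4,), device=device)
optimizer.zero_grad()
with autocast(device, enabled=amp_enabled):
    loss = nn.functional.cross_entropy(model(x), y)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # unscale first so the threshold is in true gradient units
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)      # skips the update if inf/NaN gradients are found
scaler.update()
```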
+ def evaluate(
+     model: nn.Module,
+     dataloader: DataLoader,
+     criterion: nn.Module,
+     device: torch.device,
+     amp_enabled: bool,
+     amp_device_type: str,
+ ) -> Tuple[float, float]:
+     """Evaluates the model and returns loss and accuracy."""
+
+     model.eval()
+     running_loss = 0.0
+     correct = 0
+     total = 0
+
+     with torch.no_grad():
+         for inputs, targets in dataloader:
+             inputs = inputs.to(device)
+             targets = targets.to(device)
+             with autocast(amp_device_type, enabled=amp_enabled):
+                 outputs = model(inputs)
+                 loss = criterion(outputs, targets)
+             running_loss += loss.item() * inputs.size(0)
+             preds = outputs.argmax(dim=1)
+             correct += (preds == targets).sum().item()
+             total += targets.size(0)
+
+     epoch_loss = running_loss / max(total, 1)
+     epoch_acc = correct / max(total, 1)
+     return epoch_loss, epoch_acc
+
+
+ def main() -> None:
+     args = parse_args()
+     set_seed(args.seed)
+
+     data_dir = args.data_dir.expanduser().resolve()
+     if not data_dir.exists():
+         raise FileNotFoundError(f"Dataset directory not found: {data_dir}")
+
+     # Build samples for each split based on Food-101 metadata.
+     train_samples, val_samples, test_samples, idx_to_class = load_food101_splits(data_dir, args.val_split, args.seed)
+     num_classes = len(idx_to_class)
+
+     # Prepare augmentations and dataloaders.
+     train_transform, eval_transform = build_transforms(args.image_size)
+     train_loader, val_loader, test_loader = create_dataloaders(
+         train_samples,
+         val_samples,
+         test_samples,
+         train_transform,
+         eval_transform,
+         args.batch_size,
+         args.num_workers,
+     )
+
+     # Initialize model, loss, and optimizer on the desired device.
+     device = torch.device(args.device)
+     model = build_model(
+         model_name=args.model,
+         num_classes=num_classes,
+         pretrained=args.pretrained,
+         freeze_backbone=args.freeze_backbone,
+     ).to(device)
+     criterion = nn.CrossEntropyLoss()
+     trainable_params = [p for p in model.parameters() if p.requires_grad]
+     if not trainable_params:
+         trainable_params = list(model.parameters())
+     optimizer = Adam(trainable_params, lr=args.learning_rate)
+
+     amp_enabled = args.use_amp and device.type == "cuda"
+     amp_device_type = "cuda" if device.type == "cuda" else "cpu"
+     scaler = GradScaler(amp_device_type, enabled=amp_enabled)
+     clip_params: Sequence[torch.nn.Parameter] = trainable_params
+
+     checkpoint_dir = args.checkpoint_dir
+     checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+     experiment_name = args.experiment_name or args.model
+     checkpoint_path = checkpoint_dir / f"{experiment_name}.pth"
+     log_path = args.log_dir / f"{experiment_name}.csv"
+
+     monitor_max = args.early_stop_metric == "accuracy"
+     best_metric = -float("inf") if monitor_max else float("inf")
+     patience_counter = 0
+     best_val_acc = 0.0
+     epochs_completed = 0
+
+     wandb_run = maybe_init_wandb(
+         args,
+         {
+             "num_classes": num_classes,
+             "train_samples": len(train_samples),
+             "val_samples": len(val_samples),
+             "test_samples": len(test_samples),
+             "amp_enabled": amp_enabled,
+         },
+     )
+
+     # Standard training loop with validation monitoring.
+     for epoch in range(1, args.epochs + 1):
+         train_loss, train_acc = train_one_epoch(
+             model,
+             train_loader,
+             criterion,
+             optimizer,
+             device,
+             scaler,
+             args.grad_clip_norm,
+             amp_enabled,
+             clip_params,
+             amp_device_type,
+         )
+         val_loss, val_acc = evaluate(model, val_loader, criterion, device, amp_enabled, amp_device_type)
+         print(
+             f"Epoch {epoch:02d}: train_loss={train_loss:.4f} train_acc={train_acc:.4f} "
+             f"val_loss={val_loss:.4f} val_acc={val_acc:.4f}"
+         )
+         write_metrics(
+             log_path,
+             [
+                 {
+                     "model": args.model,
+                     "experiment": experiment_name,
+                     "epoch": epoch,
+                     "split": "train",
+                     "loss": train_loss,
+                     "accuracy": train_acc,
+                 },
+                 {
+                     "model": args.model,
+                     "experiment": experiment_name,
+                     "epoch": epoch,
+                     "split": "val",
+                     "loss": val_loss,
+                     "accuracy": val_acc,
+                 },
+             ],
+         )
+         if wandb_run is not None:
+             wandb_run.log(
+                 {
+                     "epoch": epoch,
+                     "train/loss": train_loss,
+                     "train/accuracy": train_acc,
+                     "val/loss": val_loss,
+                     "val/accuracy": val_acc,
+                 },
+                 step=epoch,
+             )
+
+         if val_acc > best_val_acc:
+             best_val_acc = val_acc
+
+         monitor_value = val_acc if monitor_max else val_loss
+         improved = (
+             monitor_value > best_metric + args.early_stop_min_delta
+             if monitor_max
+             else monitor_value < best_metric - args.early_stop_min_delta
+         )
+         if improved:
+             best_metric = monitor_value
+             patience_counter = 0
+             torch.save(model.state_dict(), checkpoint_path)
+         else:
+             patience_counter += 1
+             if patience_counter >= args.early_stop_patience:
+                 print(
+                     f"Early stopping triggered at epoch {epoch}. "
+                     f"Best {args.early_stop_metric}: {best_metric:.4f}"
+                 )
+                 epochs_completed = epoch
+                 break
+
+         epochs_completed = epoch
+     else:
+         epochs_completed = args.epochs
+
+     # Ensure we persist weights even if validation never improved.
+     if not checkpoint_path.exists():
+         torch.save(model.state_dict(), checkpoint_path)
+
+     # Load best checkpoint before final evaluation.
+     try:
+         state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
+     except TypeError:
+         state_dict = torch.load(checkpoint_path, map_location=device)
+     model.load_state_dict(state_dict)
+
+     # Final evaluation on the held-out test set.
+     test_loss, test_acc = evaluate(model, test_loader, criterion, device, amp_enabled, amp_device_type)
+     print(f"Test metrics - loss: {test_loss:.4f} accuracy: {test_acc:.4f}")
+     write_metrics(
+         log_path,
+         [
+             {
+                 "model": args.model,
+                 "experiment": experiment_name,
+                 "epoch": epochs_completed,
+                 "split": "test",
+                 "loss": test_loss,
+                 "accuracy": test_acc,
+             }
+         ],
+     )
+
+     if wandb_run is not None:
+         wandb_run.log(
+             {
+                 "test/loss": test_loss,
+                 "test/accuracy": test_acc,
+                 "best/val_accuracy": best_val_acc,
+                 f"best/val_{args.early_stop_metric}": best_metric,
+             },
+             step=epochs_completed,
+         )
+         wandb_run.finish()
+
+
+ if __name__ == "__main__":
+     main()
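After training, the checkpoint is a plain state_dict, so it can be reloaded for a quick sanity check the same way main() does before the test pass; the path below is a placeholder for checkpoint_dir / "<experiment_name>.pth":

```python
# Illustrative reload of a trained checkpoint; the path is a placeholder,
# and build_model is the factory defined above.
import torch

model = build_model(model_name="efficientnet_b0", num_classes=101, pretrained=False, freeze_backbone=False)
state_dict = torch.load("checkpoints/efficientnet_b0.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
```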