Spaces: LeooNic committed · b29fbac
1 Parent(s): b6c74eb
Deploy Food-101 classifier with 84.49% accuracy
- EfficientNet-B0 model achieving 84.49% test accuracy
- ONNX optimized for 7ms inference time
- Interactive Gradio interface with example images
- Supports 101 food classes from Food-101 dataset
- Ready for production deployment
- README.md +114 -6
- app.py +41 -0
- food-101/food-101/images/hamburger/100057.jpg +0 -0
- food-101/food-101/images/ice_cream/1004744.jpg +0 -0
- food-101/food-101/images/pizza/1001116.jpg +0 -0
- food-101/food-101/meta/classes.txt +101 -0
- food-101/food-101/meta/labels.txt +101 -0
- food-101/food-101/meta/test.json +0 -0
- food-101/food-101/meta/test.txt +0 -0
- food-101/food-101/meta/train.json +0 -0
- food-101/food-101/meta/train.txt +0 -0
- gradio_app/app.py +225 -0
- models/efficientnet_b0_food101.onnx +3 -0
- requirements.txt +23 -0
- scripts/__pycache__/check_gpu.cpython-312.pyc +0 -0
- scripts/__pycache__/evaluate.cpython-312.pyc +0 -0
- scripts/__pycache__/gradcam.cpython-312.pyc +0 -0
- scripts/__pycache__/predict.cpython-312.pyc +0 -0
- scripts/__pycache__/train.cpython-312.pyc +0 -0
- scripts/check_gpu.py +21 -0
- scripts/evaluate.py +136 -0
- scripts/export_model.py +159 -0
- scripts/gradcam.py +143 -0
- scripts/predict.py +175 -0
- scripts/train.py +590 -0
README.md
CHANGED
@@ -1,14 +1,122 @@
 ---
-title: Food
-emoji:
-colorFrom:
 colorTo: green
 sdk: gradio
-sdk_version:
 app_file: app.py
 pinned: false
 license: mit
-short_description: My Space
 ---
---
title: Food-101 AI Classifier
emoji: 🍔
colorFrom: blue
colorTo: green
sdk: gradio
sdk_version: 4.44.0
app_file: app.py
pinned: false
license: mit
---

# 🍔 Food-101 AI Classifier

An AI-powered food image classifier trained on the Food-101 dataset, capable of recognizing 101 different types of food with high accuracy.

## 🎯 Model Performance

- **Architecture**: EfficientNet-B0 (fine-tuned)
- **Test Accuracy**: 84.49%
- **Top-5 Accuracy**: 96.72%
- **Inference Speed**: ~7 ms per image
- **Model Size**: 15.77 MB (ONNX optimized)

## 🚀 Features

- **Fast Inference**: ONNX-optimized model for low-latency predictions
- **High Accuracy**: 84.49% top-1 accuracy on the Food-101 test set
- **User-Friendly Interface**: Clean and intuitive Gradio web interface
- **Real-time Predictions**: Upload any food image and get instant results
- **Confidence Scores**: See how confident the model is about each prediction

## 📊 Dataset

This model was trained on the [Food-101 dataset](https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/), which contains:
- **101 food categories**
- **1,000 images per category**
- **Total**: 101,000 images

## 🏆 Recognized Food Categories

The model can classify the following 101 food types:

```
apple_pie, baby_back_ribs, baklava, beef_carpaccio, beef_tartare, beet_salad,
beignets, bibimbap, bread_pudding, breakfast_burrito, bruschetta, caesar_salad,
cannoli, caprese_salad, carrot_cake, ceviche, cheese_plate, cheesecake,
chicken_curry, chicken_quesadilla, chicken_wings, chocolate_cake, chocolate_mousse,
churros, clam_chowder, club_sandwich, crab_cakes, creme_brulee, croque_madame,
cup_cakes, deviled_eggs, donuts, dumplings, edamame, eggs_benedict, escargots,
falafel, filet_mignon, fish_and_chips, foie_gras, french_fries, french_onion_soup,
french_toast, fried_calamari, fried_rice, frozen_yogurt, garlic_bread, gnocchi,
greek_salad, grilled_cheese_sandwich, grilled_salmon, guacamole, gyoza, hamburger,
hot_and_sour_soup, hot_dog, huevos_rancheros, hummus, ice_cream, lasagna,
lobster_bisque, lobster_roll_sandwich, macaroni_and_cheese, macarons, miso_soup,
mussels, nachos, omelette, onion_rings, oysters, pad_thai, paella, pancakes,
panna_cotta, peking_duck, pho, pizza, pork_chop, poutine, prime_rib, pulled_pork_sandwich,
ramen, ravioli, red_velvet_cake, risotto, samosa, sashimi, scallops, seaweed_salad,
shrimp_and_grits, spaghetti_bolognese, spaghetti_carbonara, spring_rolls, steak,
strawberry_shortcake, sushi, tacos, takoyaki, tiramisu, tuna_tartare, waffles
```

## 🛠️ Technical Details

### Architecture
- **Base Model**: EfficientNet-B0 (pre-trained on ImageNet)
- **Fine-tuning**: Transfer learning on the Food-101 dataset (sketched below)
- **Optimization**: ONNX Runtime for fast inference
- **Input Size**: 224×224×3 RGB images

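In torchvision terms, the fine-tuning setup above amounts to loading ImageNet weights and swapping the classification head for 101 food classes; a minimal sketch (the full training logic, including backbone freezing, lives in `scripts/train.py`):

```python
import torch.nn as nn
from torchvision import models
from torchvision.models import EfficientNet_B0_Weights

# Start from ImageNet-pretrained EfficientNet-B0 and replace the final
# linear layer so it predicts the 101 Food-101 classes.
model = models.efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 101)
```
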
### Training Pipeline
1. **Data Augmentation**: Albumentations library for robust training (see the sketch below)
2. **Transfer Learning**: Fine-tuned pre-trained EfficientNet-B0
3. **Advanced Training**: Early stopping, gradient clipping, AMP
4. **Validation**: 10% held-out validation set for model selection

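A condensed sketch of steps 1 and 3, assuming the same Albumentations + `torch.amp` stack that `scripts/train.py` imports; the transform list and the `max_norm` value here are illustrative, not the exact training configuration:

```python
import albumentations as A
import torch
from albumentations.pytorch import ToTensorV2
from torch.amp import GradScaler, autocast

# Illustrative training-time augmentation; the exact list lives in scripts/train.py.
train_transform = A.Compose([
    A.Resize(height=256, width=256),
    A.RandomCrop(height=224, width=224),
    A.HorizontalFlip(p=0.5),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

def train_step(model, images, labels, optimizer, scaler: GradScaler, device) -> float:
    """One mixed-precision training step with gradient clipping."""
    optimizer.zero_grad(set_to_none=True)
    with autocast(device_type=device.type):
        loss = torch.nn.functional.cross_entropy(model(images), labels)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # so clipping sees the true gradient magnitudes
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # illustrative value
    scaler.step(optimizer)
    scaler.update()
    return loss.item()
```
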
### Deployment Stack
- **Model Format**: ONNX (optimized for inference)
- **Backend**: Python with ONNX Runtime (see the inference sketch below)
- **Frontend**: Gradio web interface
- **Hosting**: Hugging Face Spaces

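A minimal sketch of how the exported ONNX file is consumed at inference time; `scripts/predict.py` contains the full version with Albumentations preprocessing and top-k handling. The `input`/`output` tensor names match those set in `scripts/export_model.py`:

```python
import numpy as np
import onnxruntime as ort
from PIL import Image

# Load the exported network once; CPUExecutionProvider matches the Spaces CPU hardware.
session = ort.InferenceSession(
    "models/efficientnet_b0_food101.onnx",
    providers=["CPUExecutionProvider"],
)

def predict_logits(image_path: str) -> np.ndarray:
    """Resize to 224x224, normalize with ImageNet statistics, and run the ONNX graph."""
    image = Image.open(image_path).convert("RGB").resize((224, 224))
    x = np.asarray(image, dtype=np.float32) / 255.0
    x = (x - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])
    x = x.transpose(2, 0, 1)[np.newaxis].astype(np.float32)  # NCHW batch of one
    return session.run(["output"], {"input": x})[0]
```
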
## 📝 How to Use

1. **Upload an Image**: Click on the upload area or drag & drop a food image
2. **Set Predictions**: Choose how many top predictions you want (1-10)
3. **Get Results**: Click "Submit" to see predictions with confidence scores
4. **Try Examples**: Use the provided example images to test the model (a programmatic alternative is sketched below)

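If you prefer to query the Space programmatically rather than through the browser, something along these lines should work with `gradio_client`; the Space id, image path, and endpoint name below are placeholders to adjust, not values taken from this repository:

```python
from gradio_client import Client, handle_file

# "LeooNic/Food" is a placeholder Space id; point it at the deployed Space.
client = Client("LeooNic/Food")

result = client.predict(
    handle_file("my_food_photo.jpg"),  # local food image to classify
    5,                                 # number of top predictions
    api_name="/predict",               # typical default endpoint for a gr.Interface
)
print(result)
```
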
## 🧪 Model Evaluation

### Performance Metrics
- **Accuracy**: 84.49%
- **Macro F1-Score**: 84.40%
- **Weighted F1-Score**: 84.40%
- **Top-5 Accuracy**: 96.72%

### Most Challenging Classes
The model struggles most with:
1. Steak (51.35% F1-score)
2. Apple Pie (63.36% F1-score)
3. Pork Chop (66.02% F1-score)

## 🔬 Explainability

The model includes Grad-CAM visualization capabilities to show which parts of the image the AI focuses on when making predictions, providing transparency into the decision-making process.

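The heatmaps come from `scripts/gradcam.py`; a condensed usage sketch of its helpers (checkpoint and image paths here are illustrative, and the snippet assumes it is run from the `scripts/` directory):

```python
import numpy as np
import torch
from PIL import Image

from gradcam import build_preprocess, generate_gradcam, overlay_heatmap
from train import build_model

# Illustrative checkpoint path; use whatever train.py actually saved.
model = build_model("efficientnet_b0", num_classes=101, pretrained=False, freeze_backbone=False)
model.load_state_dict(torch.load("checkpoints/efficientnet_b0_best.pt", map_location="cpu"))
model.eval()

original = np.array(Image.open("some_food.jpg").convert("RGB"))
tensor = build_preprocess(224)(image=original)["image"].unsqueeze(0)
heatmap = generate_gradcam(model, tensor, target_class=None)  # None -> use the predicted class
Image.fromarray(overlay_heatmap(original, heatmap, alpha=0.4)).save("gradcam_overlay.jpg")
```
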
## 📜 License

This project is licensed under the MIT License - see the LICENSE file for details.

## 🙏 Acknowledgments

- **Dataset**: Food-101 dataset by Bossard et al.
- **Framework**: PyTorch and torchvision
- **Optimization**: ONNX Runtime
- **Interface**: Gradio
- **Hosting**: Hugging Face Spaces

---

Built with ❤️ using PyTorch, ONNX, and Gradio
app.py
ADDED
@@ -0,0 +1,41 @@
"""Hugging Face Spaces deployment app for Food-101 classification."""

# This is the main app file expected by Hugging Face Spaces.
# It imports and runs the Gradio app from gradio_app/app.py.

import sys
from pathlib import Path

# Add current directory to path
sys.path.append(str(Path(__file__).parent))

# Import the Gradio app
from gradio_app.app import GradioFood101App


def main():
    """Main function for Hugging Face Spaces deployment."""
    try:
        # Initialize the app
        print("[HF SPACES] Initializing Food-101 Classifier App...")
        app = GradioFood101App()

        # Create interface
        print("[HF SPACES] Creating Gradio interface...")
        interface = app.create_interface()

        # Launch the app for HF Spaces.
        # Note: queueing is on by default in Gradio 4.x and the old `enable_queue`
        # launch argument was removed, so it is no longer passed here.
        print("[HF SPACES] Launching app for Hugging Face Spaces...")
        interface.launch(
            share=False,
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True,
        )

    except Exception as e:
        print(f"[ERROR] Failed to launch HF Spaces app: {e}")
        raise


if __name__ == "__main__":
    main()
food-101/food-101/images/hamburger/100057.jpg
ADDED
food-101/food-101/images/ice_cream/1004744.jpg
ADDED
food-101/food-101/images/pizza/1001116.jpg
ADDED
food-101/food-101/meta/classes.txt
ADDED
@@ -0,0 +1,101 @@
apple_pie
baby_back_ribs
baklava
beef_carpaccio
beef_tartare
beet_salad
beignets
bibimbap
bread_pudding
breakfast_burrito
bruschetta
caesar_salad
cannoli
caprese_salad
carrot_cake
ceviche
cheesecake
cheese_plate
chicken_curry
chicken_quesadilla
chicken_wings
chocolate_cake
chocolate_mousse
churros
clam_chowder
club_sandwich
crab_cakes
creme_brulee
croque_madame
cup_cakes
deviled_eggs
donuts
dumplings
edamame
eggs_benedict
escargots
falafel
filet_mignon
fish_and_chips
foie_gras
french_fries
french_onion_soup
french_toast
fried_calamari
fried_rice
frozen_yogurt
garlic_bread
gnocchi
greek_salad
grilled_cheese_sandwich
grilled_salmon
guacamole
gyoza
hamburger
hot_and_sour_soup
hot_dog
huevos_rancheros
hummus
ice_cream
lasagna
lobster_bisque
lobster_roll_sandwich
macaroni_and_cheese
macarons
miso_soup
mussels
nachos
omelette
onion_rings
oysters
pad_thai
paella
pancakes
panna_cotta
peking_duck
pho
pizza
pork_chop
poutine
prime_rib
pulled_pork_sandwich
ramen
ravioli
red_velvet_cake
risotto
samosa
sashimi
scallops
seaweed_salad
shrimp_and_grits
spaghetti_bolognese
spaghetti_carbonara
spring_rolls
steak
strawberry_shortcake
sushi
tacos
takoyaki
tiramisu
tuna_tartare
waffles
food-101/food-101/meta/labels.txt
ADDED
@@ -0,0 +1,101 @@
Apple pie
Baby back ribs
Baklava
Beef carpaccio
Beef tartare
Beet salad
Beignets
Bibimbap
Bread pudding
Breakfast burrito
Bruschetta
Caesar salad
Cannoli
Caprese salad
Carrot cake
Ceviche
Cheesecake
Cheese plate
Chicken curry
Chicken quesadilla
Chicken wings
Chocolate cake
Chocolate mousse
Churros
Clam chowder
Club sandwich
Crab cakes
Creme brulee
Croque madame
Cup cakes
Deviled eggs
Donuts
Dumplings
Edamame
Eggs benedict
Escargots
Falafel
Filet mignon
Fish and chips
Foie gras
French fries
French onion soup
French toast
Fried calamari
Fried rice
Frozen yogurt
Garlic bread
Gnocchi
Greek salad
Grilled cheese sandwich
Grilled salmon
Guacamole
Gyoza
Hamburger
Hot and sour soup
Hot dog
Huevos rancheros
Hummus
Ice cream
Lasagna
Lobster bisque
Lobster roll sandwich
Macaroni and cheese
Macarons
Miso soup
Mussels
Nachos
Omelette
Onion rings
Oysters
Pad thai
Paella
Pancakes
Panna cotta
Peking duck
Pho
Pizza
Pork chop
Poutine
Prime rib
Pulled pork sandwich
Ramen
Ravioli
Red velvet cake
Risotto
Samosa
Sashimi
Scallops
Seaweed salad
Shrimp and grits
Spaghetti bolognese
Spaghetti carbonara
Spring rolls
Steak
Strawberry shortcake
Sushi
Tacos
Takoyaki
Tiramisu
Tuna tartare
Waffles
food-101/food-101/meta/test.json
ADDED
The diff for this file is too large to render.
See raw diff
food-101/food-101/meta/test.txt
ADDED
The diff for this file is too large to render.
See raw diff
food-101/food-101/meta/train.json
ADDED
The diff for this file is too large to render.
See raw diff
food-101/food-101/meta/train.txt
ADDED
The diff for this file is too large to render.
See raw diff
gradio_app/app.py
ADDED
@@ -0,0 +1,225 @@
"""Gradio demo app for Food-101 classification."""

import sys
from pathlib import Path
from typing import Tuple, Dict, List
import time
import tempfile

import gradio as gr
import numpy as np
from PIL import Image

# Add scripts directory to path
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root / "scripts"))

from predict import Food101Predictor
from train import load_food101_splits


class GradioFood101App:
    """Gradio application for Food-101 classification."""

    def __init__(self):
        """Initialize the Gradio app with the ONNX predictor."""
        self.predictor = None
        self.load_model()

    def load_model(self):
        """Load the ONNX predictor."""
        try:
            # Paths
            model_path = project_root / "models/efficientnet_b0_food101.onnx"
            data_dir = project_root / "food-101/food-101"

            # Load class names
            _, _, _, idx_to_class = load_food101_splits(data_dir, val_split=0.1, seed=42)
            class_names = [idx_to_class[i] for i in range(len(idx_to_class))]

            # Initialize predictor
            self.predictor = Food101Predictor(model_path, class_names)
            print(f"[GRADIO] Model loaded successfully with {len(class_names)} classes")

        except Exception as e:
            print(f"[ERROR] Failed to load model: {e}")
            raise

    def predict_image(self, image: Image.Image, top_k: int = 5) -> Tuple[Dict, str]:
        """
        Predict food class for uploaded image.

        Args:
            image: PIL Image
            top_k: Number of top predictions

        Returns:
            (confidences_dict, info_text)
        """
        if image is None:
            return {}, "Please upload an image first!"

        if self.predictor is None:
            return {}, "Model not loaded. Please try again."

        temp_path = None
        try:
            # Save image temporarily
            with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file:
                image.save(tmp_file.name)
                temp_path = Path(tmp_file.name)

            # Run prediction
            start_time = time.time()
            predictions, probabilities, inference_time = self.predictor.predict(temp_path, top_k)
            total_time = (time.time() - start_time) * 1000

            # Clean up
            temp_path.unlink(missing_ok=True)

            # Format results for Gradio
            confidences = {}
            for pred, prob in zip(predictions, probabilities):
                confidences[pred.replace('_', ' ').title()] = float(prob)

            # Create info text
            info_lines = [
                f"🔍 **Prediction Results**",
                f"⚡ **Inference Time**: {inference_time:.2f}ms",
                f"🕒 **Total Time**: {total_time:.2f}ms",
                f"🧠 **Model**: EfficientNet-B0 (ONNX)",
                f"📊 **Top Prediction**: {predictions[0].replace('_', ' ').title()} ({probabilities[0]*100:.1f}%)"
            ]

            info_text = "\n".join(info_lines)

            return confidences, info_text

        except Exception as e:
            # Guard against the temp file never having been created before the failure.
            if temp_path is not None:
                temp_path.unlink(missing_ok=True)
            return {}, f"❌ **Error**: {str(e)}"

    def get_examples(self) -> List[List]:
        """Get example images for the demo."""
        examples_dir = project_root / "food-101/food-101/images"
        examples = []

        # Select a few example images from different classes
        example_classes = ['pizza', 'hamburger', 'ice_cream']

        for class_name in example_classes:
            class_dir = examples_dir / class_name
            if class_dir.exists():
                # Get first image from class
                images = list(class_dir.glob("*.jpg"))
                if images:
                    # Format: [image_path, top_k_value]
                    examples.append([str(images[0]), 5])

        # If no examples found, return empty list (Gradio will handle gracefully)
        return examples if examples else []

    def create_interface(self) -> gr.Interface:
        """Create and return the Gradio interface."""

        # Custom CSS for better styling
        css = """
        .main-header {
            text-align: center;
            background: linear-gradient(90deg, #ff6b6b, #4ecdc4);
            -webkit-background-clip: text;
            -webkit-text-fill-color: transparent;
            font-size: 2.5em;
            font-weight: bold;
            margin-bottom: 20px;
        }
        .info-box {
            background-color: #f0f8ff;
            border-left: 5px solid #4ecdc4;
            padding: 15px;
            margin: 10px 0;
            border-radius: 5px;
        }
        """

        # Interface description
        description = """
        ## 🍕 Food-101 Image Classifier

        Upload an image of food and get AI-powered predictions! This demo uses a fine-tuned **EfficientNet-B0** model
        trained on the Food-101 dataset to classify 101 different types of food.

        ### 🎯 **Model Performance**
        - **Accuracy**: 84.49% on test set
        - **Inference Speed**: ~7ms per image
        - **Classes**: 101 different food types

        ### 🚀 **How to use**
        1. Upload an image or try one of our examples
        2. Adjust the number of top predictions (1-10)
        3. Click Submit to get predictions with confidence scores!
        """

        # Create the interface
        interface = gr.Interface(
            fn=self.predict_image,
            inputs=[
                gr.Image(
                    type="pil",
                    label="📸 Upload Food Image",
                    height=300
                ),
                gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    label="🔢 Number of Predictions"
                )
            ],
            outputs=[
                gr.Label(
                    label="🏆 Predictions & Confidence Scores",
                    num_top_classes=10
                ),
                gr.Markdown(
                    label="📊 Prediction Details"
                )
            ],
            title="🍔 Food-101 AI Classifier",
            description=description,
            examples=self.get_examples(),
            css=css,
            theme=gr.themes.Soft(),
            flagging_mode="never"
        )

        return interface


def main():
    """Main function to launch the Gradio app."""
    try:
        # Initialize the app
        print("[GRADIO] Initializing Food-101 Classifier App...")
        app = GradioFood101App()

        # Create interface
        print("[GRADIO] Creating Gradio interface...")
        interface = app.create_interface()

        # Launch the app
        print("[GRADIO] Launching app...")
        interface.launch(
            share=False,  # Set to True to create public link
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True
        )

    except Exception as e:
        print(f"[ERROR] Failed to launch Gradio app: {e}")
        raise


if __name__ == "__main__":
    main()
models/efficientnet_b0_food101.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a19fedaec8bb25ca7af8d49a71aaf2d5f71588bd96d859c252fd1e4902345179
size 16537732
requirements.txt
ADDED
@@ -0,0 +1,23 @@
# Requirements for Hugging Face Spaces deployment
# Core dependencies for Food-101 classifier

# Deep Learning & Computer Vision
torch>=2.0.0
torchvision>=0.15.0
onnxruntime>=1.20.0
onnx>=1.15.0

# Image Processing
pillow>=9.0.0
opencv-python>=4.5.0
albumentations>=1.3.0

# ML & Data Science
numpy>=1.24.0
scikit-learn>=1.3.0

# Web Interface
gradio>=4.0.0

# Utilities
# (pathlib is part of the Python standard library; the PyPI backport is not needed
# and can conflict with modern Python, so it is intentionally not listed here)
scripts/__pycache__/check_gpu.cpython-312.pyc
ADDED
Binary file (1.37 kB)
scripts/__pycache__/evaluate.cpython-312.pyc
ADDED
Binary file (9.39 kB)
scripts/__pycache__/gradcam.cpython-312.pyc
ADDED
Binary file (9.18 kB)
scripts/__pycache__/predict.cpython-312.pyc
ADDED
Binary file (9.5 kB)
scripts/__pycache__/train.cpython-312.pyc
ADDED
Binary file (28.6 kB)
scripts/check_gpu.py
ADDED
@@ -0,0 +1,21 @@
"""Utility script to report CUDA GPU availability for PyTorch."""
from __future__ import annotations

import torch


def main() -> None:
    has_cuda = torch.cuda.is_available()
    print(f"torch.cuda.is_available(): {has_cuda}")
    if has_cuda:
        num_devices = torch.cuda.device_count()
        print(f"Detected CUDA devices: {num_devices}")
        for idx in range(num_devices):
            name = torch.cuda.get_device_name(idx)
            capability = torch.cuda.get_device_capability(idx)
            print(f"  - Device {idx}: {name} (compute capability {capability[0]}.{capability[1]})")
    else:
        print("No CUDA-capable GPU detected. Training will fall back to CPU.")


if __name__ == "__main__":
    main()
scripts/evaluate.py
ADDED
@@ -0,0 +1,136 @@
"""Evaluation utilities for Food-101 classifiers (Phase 5)."""
from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Dict, List, Sequence

import albumentations as A
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix, top_k_accuracy_score
from torch.utils.data import DataLoader, Dataset

from train import BaselineCNN, Sample, build_model, load_food101_splits, set_seed


class Food101EvalDataset(Dataset):
    """Thin dataset wrapper that applies evaluation transforms."""

    def __init__(self, samples: Sequence[Sample], transform: A.BasicTransform) -> None:
        self.samples = list(samples)
        self.transform = transform

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Dict[str, torch.Tensor | int | str]:
        sample = self.samples[index]
        with sample.path.open("rb") as file:
            array = np.array(Image.open(file).convert("RGB"))
        tensor = self.transform(image=array)["image"]
        return {"image": tensor, "label": sample.label, "path": str(sample.path)}


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Evaluate Food-101 checkpoints")
    parser.add_argument("--data-dir", type=Path, default=Path("data/raw/food-101"), help="Dataset root")
    parser.add_argument("--checkpoint", type=Path, required=True, help="Checkpoint file to evaluate")
    parser.add_argument("--model", choices=["baseline", "resnet50", "efficientnet_b0"], required=True)
    parser.add_argument("--split", choices=["val", "test"], default="test", help="Dataset split to use")
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--image-size", type=int, default=224)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--topk", nargs="*", type=int, default=[1, 5], help="Top-k accuracies to report")
    parser.add_argument("--report-json", type=Path, default=None, help="Optional path to dump JSON metrics")
    return parser.parse_args()


def build_eval_transform(image_size: int) -> A.BasicTransform:
    return A.Compose(
        [
            A.Resize(height=image_size + 32, width=image_size + 32),
            A.CenterCrop(height=image_size, width=image_size),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ]
    )


def run_evaluation(args: argparse.Namespace) -> Dict[str, float]:
    set_seed(args.seed)
    device = torch.device(args.device)

    data_dir = args.data_dir.expanduser().resolve()
    train_samples, val_samples, test_samples, idx_to_class = load_food101_splits(data_dir, val_split=0.1, seed=args.seed)
    class_names: List[str] = [idx_to_class[i] for i in range(len(idx_to_class))]

    split_samples = val_samples if args.split == "val" else test_samples
    transform = build_eval_transform(args.image_size)
    dataset = Food101EvalDataset(split_samples, transform=transform)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=torch.cuda.is_available())

    model = build_model(args.model, num_classes=len(class_names), pretrained=False, freeze_backbone=False)
    try:
        state_dict = torch.load(args.checkpoint, map_location=device, weights_only=True)
    except TypeError:
        state_dict = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    all_probs: List[torch.Tensor] = []
    all_labels: List[int] = []

    with torch.no_grad():
        for batch in dataloader:
            inputs = batch["image"].to(device)
            outputs = model(inputs)
            probs = torch.softmax(outputs, dim=1)
            all_probs.append(probs.cpu())
            all_labels.extend(batch["label"].tolist())

    probs_tensor = torch.cat(all_probs, dim=0)
    preds = probs_tensor.argmax(dim=1).numpy()
    labels_np = np.array(all_labels)

    report = classification_report(labels_np, preds, target_names=class_names, output_dict=True, zero_division=0)
    conf_mat = confusion_matrix(labels_np, preds)

    metrics: Dict[str, float] = {
        "accuracy": report["accuracy"],
        "macro_precision": report["macro avg"]["precision"],
        "macro_recall": report["macro avg"]["recall"],
        "macro_f1": report["macro avg"]["f1-score"],
        "weighted_f1": report["weighted avg"]["f1-score"],
    }

    for k in args.topk:
        metrics[f"top{k}_accuracy"] = top_k_accuracy_score(labels_np, probs_tensor, k=k, labels=list(range(len(class_names))))

    if args.report_json:
        args.report_json.parent.mkdir(parents=True, exist_ok=True)
        with args.report_json.open("w") as f:
            json.dump({"metrics": metrics, "classification_report": report, "confusion_matrix": conf_mat.tolist()}, f, indent=2)

    print("=== Metrics ===")
    for key, value in metrics.items():
        print(f"{key}: {value:.4f}")

    print("=== Confusion Matrix (sample) ===")
    print(conf_mat)
    return metrics


def main() -> None:
    args = parse_args()
    run_evaluation(args)


if __name__ == "__main__":
    main()
scripts/export_model.py
ADDED
@@ -0,0 +1,159 @@
"""Export PyTorch models to ONNX format for optimized inference."""

import argparse
from pathlib import Path
import torch
import onnx
import onnxruntime as ort
import numpy as np

from train import build_model, load_food101_splits, set_seed


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Export PyTorch model to ONNX")
    parser.add_argument("--checkpoint", type=Path, required=True, help="Path to PyTorch checkpoint")
    parser.add_argument("--model", choices=["baseline", "resnet50", "efficientnet_b0"], required=True)
    parser.add_argument("--output", type=Path, required=True, help="Output ONNX file path")
    parser.add_argument("--data-dir", type=Path, default=Path("food-101/food-101"), help="Dataset root")
    parser.add_argument("--input-size", type=int, default=224, help="Input image size")
    parser.add_argument("--batch-size", type=int, default=1, help="Batch size for export")
    parser.add_argument("--device", type=str, default="cpu", help="Device for export")
    parser.add_argument("--opset-version", type=int, default=11, help="ONNX opset version")
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args()


def export_to_onnx(
    model: torch.nn.Module,
    output_path: Path,
    input_size: int,
    batch_size: int,
    device: torch.device,
    opset_version: int = 11
) -> None:
    """Export PyTorch model to ONNX format."""

    model.eval()

    # Create dummy input tensor
    dummy_input = torch.randn(batch_size, 3, input_size, input_size, device=device)

    # Export to ONNX
    output_path.parent.mkdir(parents=True, exist_ok=True)

    torch.onnx.export(
        model,
        dummy_input,
        str(output_path),
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        }
    )

    print(f"[SUCCESS] Model exported to {output_path}")


def verify_onnx_model(onnx_path: Path, pytorch_model: torch.nn.Module, input_size: int, device: torch.device) -> None:
    """Verify that ONNX model produces same outputs as PyTorch model."""

    # Load ONNX model
    onnx_model = onnx.load(str(onnx_path))
    onnx.checker.check_model(onnx_model)
    print("[SUCCESS] ONNX model is valid")

    # Create ONNX Runtime session
    ort_session = ort.InferenceSession(str(onnx_path))

    # Create test input
    test_input = torch.randn(1, 3, input_size, input_size, device=device)

    # PyTorch inference
    pytorch_model.eval()
    with torch.no_grad():
        pytorch_output = pytorch_model(test_input).cpu().numpy()

    # ONNX inference
    onnx_input = test_input.cpu().numpy()
    onnx_output = ort_session.run(['output'], {'input': onnx_input})[0]

    # Compare outputs
    max_diff = np.max(np.abs(pytorch_output - onnx_output))
    print(f"[SUCCESS] Max difference between PyTorch and ONNX outputs: {max_diff:.6f}")

    if max_diff < 1e-5:
        print("[SUCCESS] ONNX model verification successful!")
    else:
        print("[WARNING] Large difference detected - verify model compatibility")


def get_model_info(model: torch.nn.Module, input_size: int) -> None:
    """Print model information."""

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"Model Information:")
    print(f"  Total parameters: {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")
    print(f"  Input size: {input_size}x{input_size}x3")


def main() -> None:
    args = parse_args()
    set_seed(args.seed)

    device = torch.device(args.device)

    # Load dataset info to get number of classes
    data_dir = args.data_dir.expanduser().resolve()
    _, _, _, idx_to_class = load_food101_splits(data_dir, val_split=0.1, seed=args.seed)
    num_classes = len(idx_to_class)

    # Load model
    print(f"[INFO] Loading {args.model} model...")
    model = build_model(args.model, num_classes=num_classes, pretrained=False, freeze_backbone=False)

    # Load checkpoint
    try:
        state_dict = torch.load(args.checkpoint, map_location=device, weights_only=True)
    except TypeError:
        state_dict = torch.load(args.checkpoint, map_location=device)

    model.load_state_dict(state_dict)
    model.to(device)

    # Print model info
    get_model_info(model, args.input_size)

    # Export to ONNX
    print(f"[INFO] Exporting to ONNX...")
    export_to_onnx(
        model=model,
        output_path=args.output,
        input_size=args.input_size,
        batch_size=args.batch_size,
        device=device,
        opset_version=args.opset_version
    )

    # Verify ONNX model
    print(f"[INFO] Verifying ONNX model...")
    verify_onnx_model(args.output, model, args.input_size, device)

    # Print file size
    file_size_mb = args.output.stat().st_size / (1024 * 1024)
    print(f"[INFO] ONNX file size: {file_size_mb:.2f} MB")

    print("[SUCCESS] Export completed successfully!")


if __name__ == "__main__":
    main()
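A minimal usage sketch of the two helpers above, outside the CLI; the checkpoint path is illustrative and the snippet assumes it is run from the `scripts/` directory:

```python
from pathlib import Path

import torch

from export_model import export_to_onnx, verify_onnx_model
from train import build_model

# Illustrative checkpoint path; use whatever train.py actually saved.
model = build_model("efficientnet_b0", num_classes=101, pretrained=False, freeze_backbone=False)
model.load_state_dict(torch.load("checkpoints/efficientnet_b0_best.pt", map_location="cpu"))

out = Path("models/efficientnet_b0_food101.onnx")
export_to_onnx(model, out, input_size=224, batch_size=1, device=torch.device("cpu"), opset_version=11)
verify_onnx_model(out, model, input_size=224, device=torch.device("cpu"))
```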
scripts/gradcam.py
ADDED
@@ -0,0 +1,143 @@
"""Grad-CAM utilities for Food-101 models."""
from __future__ import annotations

import argparse
from pathlib import Path
from typing import Dict, List

import albumentations as A
import cv2
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image

from train import build_model, load_food101_splits, set_seed


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Generate Grad-CAM visualizations")
    parser.add_argument("--data-dir", type=Path, default=Path("data/raw/food-101"))
    parser.add_argument("--checkpoint", type=Path, required=True)
    parser.add_argument("--model", choices=["baseline", "resnet50", "efficientnet_b0"], required=True)
    parser.add_argument("--class-index", type=int, default=None, help="Target class index for Grad-CAM (defaults to prediction)")
    parser.add_argument("--image", type=Path, required=True, help="Path to input image file")
    parser.add_argument("--output", type=Path, required=True, help="Output heatmap path")
    parser.add_argument("--image-size", type=int, default=224)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--alpha", type=float, default=0.4, help="Blending factor for heatmap overlay")
    return parser.parse_args()


def build_preprocess(image_size: int) -> A.BasicTransform:
    return A.Compose(
        [
            A.Resize(height=image_size, width=image_size),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ]
    )


def get_last_conv_layer(model: torch.nn.Module) -> torch.nn.Module:
    # EfficientNet B0
    if hasattr(model, "features") and hasattr(model.features, "_modules"):
        return model.features[-1][-1]  # Last block, last layer

    # ResNet50
    if hasattr(model, "layer4"):
        return model.layer4[-1].conv3  # type: ignore[attr-defined]

    # Baseline CNN
    if hasattr(model, "features"):
        for module in reversed(model.features):
            if isinstance(module, torch.nn.Conv2d):
                return module

    # Generic fallback
    if hasattr(model, "classifier") and isinstance(model.classifier, torch.nn.Sequential):
        for module in reversed(model.classifier):
            if isinstance(module, torch.nn.Conv2d):
                return module

    raise RuntimeError("Could not automatically determine last convolutional layer")


def generate_gradcam(
    model: torch.nn.Module,
    image_tensor: torch.Tensor,
    target_class: int | None,
) -> np.ndarray:
    gradients: List[torch.Tensor] = []
    activations: List[torch.Tensor] = []

    def backward_hook(module: torch.nn.Module, grad_input: tuple[torch.Tensor, ...], grad_output: tuple[torch.Tensor, ...]) -> None:
        gradients.append(grad_output[0])

    def forward_hook(module: torch.nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor) -> None:
        activations.append(output)

    target_layer = get_last_conv_layer(model)
    handle_fwd = target_layer.register_forward_hook(forward_hook)
    handle_bwd = target_layer.register_full_backward_hook(backward_hook)

    try:
        output = model(image_tensor)
        if target_class is None:
            target_class = int(output.argmax(dim=1).item())
        loss = output[:, target_class].sum()
        model.zero_grad()
        loss.backward()

        grads = gradients[0]
        acts = activations[0]
        weights = grads.mean(dim=(2, 3), keepdim=True)
        cam = torch.relu((weights * acts).sum(dim=1, keepdim=True))
        cam = torch.nn.functional.interpolate(cam, size=image_tensor.shape[2:], mode="bilinear", align_corners=False)
        cam = cam.squeeze().detach().cpu().numpy()
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        return cam
    finally:
        handle_fwd.remove()
        handle_bwd.remove()


def overlay_heatmap(original: np.ndarray, heatmap: np.ndarray, alpha: float) -> np.ndarray:
    # Resize heatmap to match original image size
    heatmap_resized = cv2.resize(heatmap, (original.shape[1], original.shape[0]))
    heatmap_color = cv2.applyColorMap((heatmap_resized * 255).astype(np.uint8), cv2.COLORMAP_JET)
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
    overlay = cv2.addWeighted(heatmap_color, alpha, original, 1 - alpha, 0)
    return overlay


def main() -> None:
    args = parse_args()
    set_seed(args.seed)

    device = torch.device(args.device)
    model = build_model(args.model, num_classes=101, pretrained=False, freeze_backbone=False)
    try:
        state_dict = torch.load(args.checkpoint, map_location=device, weights_only=True)
    except TypeError:
        state_dict = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    preprocess = build_preprocess(args.image_size)
    with args.image.open("rb") as f:
        original = np.array(Image.open(f).convert("RGB"))
    tensor = preprocess(image=original)["image"].unsqueeze(0).to(device)

    heatmap = generate_gradcam(model, tensor, args.class_index)
    overlay = overlay_heatmap(original, heatmap, alpha=args.alpha)

    args.output.parent.mkdir(parents=True, exist_ok=True)
    Image.fromarray(overlay).save(args.output)
    print(f"Grad-CAM saved to {args.output}")


if __name__ == "__main__":
    main()
scripts/predict.py
ADDED
@@ -0,0 +1,175 @@
"""Fast inference script using ONNX model for Food-101 classification."""

import argparse
import time
from pathlib import Path
from typing import Dict, List, Tuple

import albumentations as A
import numpy as np
import onnxruntime as ort
from albumentations.pytorch import ToTensorV2
from PIL import Image

from train import load_food101_splits


class Food101Predictor:
    """Fast ONNX-based predictor for Food-101 classification."""

    def __init__(self, onnx_path: Path, class_names: List[str], providers: List[str] = None):
        """Initialize predictor with ONNX model."""

        self.class_names = class_names
        self.num_classes = len(class_names)

        # Initialize ONNX Runtime session with optimal providers
        if providers is None:
            providers = ['CPUExecutionProvider']

        self.session = ort.InferenceSession(str(onnx_path), providers=providers)

        # Get input/output info
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

        # Create preprocessing transform
        self.transform = A.Compose([
            A.Resize(height=224, width=224),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])

        print(f"[INFO] Predictor initialized with {len(class_names)} classes")
        print(f"[INFO] ONNX Runtime providers: {self.session.get_providers()}")

    def preprocess_image(self, image_path: Path) -> np.ndarray:
        """Preprocess image for inference."""

        with image_path.open("rb") as f:
            image = Image.open(f).convert("RGB")

        # Convert to numpy array
        image_array = np.array(image)

        # Apply transforms
        transformed = self.transform(image=image_array)
        tensor = transformed["image"]

        # Add batch dimension and convert to numpy
        batch = tensor.unsqueeze(0).numpy()
        return batch

    def predict(self, image_path: Path, top_k: int = 5) -> Tuple[List[str], List[float], float]:
        """
        Predict food class for given image.

        Returns:
            predictions: List of top-k class names
            probabilities: List of top-k probabilities
            inference_time: Time in milliseconds
        """

        # Preprocess
        start_time = time.time()
        input_batch = self.preprocess_image(image_path)

        # Run inference
        outputs = self.session.run([self.output_name], {self.input_name: input_batch})[0]

        # Apply softmax to get probabilities
        exp_outputs = np.exp(outputs - np.max(outputs, axis=1, keepdims=True))
        probabilities = exp_outputs / np.sum(exp_outputs, axis=1, keepdims=True)

        # Get top-k predictions
        top_indices = np.argsort(probabilities[0])[::-1][:top_k]
        top_probs = probabilities[0][top_indices].tolist()
        top_classes = [self.class_names[i] for i in top_indices]

        inference_time = (time.time() - start_time) * 1000  # Convert to ms

        return top_classes, top_probs, inference_time

    def predict_batch(self, image_paths: List[Path], top_k: int = 5) -> List[Dict]:
        """Predict multiple images at once."""

        results = []
        for image_path in image_paths:
            classes, probs, time_ms = self.predict(image_path, top_k)
            results.append({
                'image_path': str(image_path),
                'predictions': classes,
                'probabilities': probs,
                'inference_time_ms': time_ms
            })

        return results


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Fast inference with ONNX model")
    parser.add_argument("--model", type=Path, required=True, help="Path to ONNX model file")
    parser.add_argument("--image", type=Path, required=True, help="Path to input image")
    parser.add_argument("--data-dir", type=Path, default=Path("food-101/food-101"), help="Dataset root for class names")
    parser.add_argument("--top-k", type=int, default=5, help="Number of top predictions to show")
    parser.add_argument("--providers", nargs="*", default=None,
                        help="ONNX Runtime providers (e.g., CPUExecutionProvider)")
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args()


def benchmark_inference(predictor: Food101Predictor, image_path: Path, num_runs: int = 10) -> None:
    """Benchmark inference speed."""

    print(f"[INFO] Benchmarking inference with {num_runs} runs...")

    times = []
    for i in range(num_runs):
        _, _, inference_time = predictor.predict(image_path, top_k=1)
        times.append(inference_time)
        if i == 0:
            print(f"[INFO] First run (cold start): {inference_time:.2f}ms")

    # Statistics
    mean_time = np.mean(times[1:])  # Exclude cold start
    std_time = np.std(times[1:])
    min_time = np.min(times[1:])
    max_time = np.max(times[1:])

    print(f"[BENCHMARK] Average inference time: {mean_time:.2f} ± {std_time:.2f}ms")
    print(f"[BENCHMARK] Min: {min_time:.2f}ms, Max: {max_time:.2f}ms")

    if mean_time < 100:
        print(f"[SUCCESS] Target latency achieved! ({mean_time:.2f}ms < 100ms)")
    else:
        print(f"[WARNING] Target latency not met ({mean_time:.2f}ms >= 100ms)")


def main() -> None:
    args = parse_args()

    # Load class names from dataset
    data_dir = args.data_dir.expanduser().resolve()
    _, _, _, idx_to_class = load_food101_splits(data_dir, val_split=0.1, seed=args.seed)
    class_names = [idx_to_class[i] for i in range(len(idx_to_class))]

    # Initialize predictor
    predictor = Food101Predictor(args.model, class_names, providers=args.providers)

    # Run prediction
    print(f"[INFO] Predicting image: {args.image}")
    predictions, probabilities, inference_time = predictor.predict(args.image, args.top_k)

    # Display results
    print(f"\n[RESULTS] Inference time: {inference_time:.2f}ms")
    print("Top predictions:")
    for i, (class_name, prob) in enumerate(zip(predictions, probabilities), 1):
        print(f"  {i}. {class_name}: {prob:.4f} ({prob*100:.2f}%)")

    # Run benchmark
    print()
    benchmark_inference(predictor, args.image)


if __name__ == "__main__":
    main()
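A minimal usage sketch for `Food101Predictor` outside the CLI, reusing the repository's own paths; it assumes it is run from the `scripts/` directory (or with `scripts/` on `sys.path`):

```python
from pathlib import Path

from predict import Food101Predictor
from train import load_food101_splits

# These paths mirror the defaults used elsewhere in this repository.
data_dir = Path("food-101/food-101")
_, _, _, idx_to_class = load_food101_splits(data_dir, val_split=0.1, seed=42)
class_names = [idx_to_class[i] for i in range(len(idx_to_class))]

predictor = Food101Predictor(Path("models/efficientnet_b0_food101.onnx"), class_names)
classes, probs, ms = predictor.predict(Path("food-101/food-101/images/pizza/1001116.jpg"), top_k=3)
print(classes, probs, f"{ms:.1f} ms")
```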
scripts/train.py
ADDED
@@ -0,0 +1,590 @@
"""Training script for Food-101 supporting baseline, transfer, and advanced training utilities."""
from __future__ import annotations

import argparse
import csv
import json
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

import albumentations as A
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch import nn
from torch.amp import GradScaler, autocast
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from torchvision import models
from torchvision.models import EfficientNet_B0_Weights, ResNet50_Weights

try:  # Optional dependency used for experiment tracking.
    import wandb  # type: ignore
except ImportError:  # pragma: no cover - handled at runtime when library missing.
    wandb = None


@dataclass(frozen=True)
class Sample:
    """Minimal container storing an image path and class index."""

    path: Path
    label: int


class Food101Dataset(Dataset):
    """Custom Dataset that loads Food-101 images and applies augmentations."""

    def __init__(self, samples: Sequence[Sample], transform: A.BasicTransform | None = None) -> None:
        self.samples = list(samples)
        self.transform = transform

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        sample = self.samples[index]
        with sample.path.open("rb") as file:
            # Convert PIL image to NumPy array so Albumentations can process it.
            array = np.array(Image.open(file).convert("RGB"))
        if self.transform is not None:
            array = self.transform(image=array)["image"]
        return array, sample.label


class BaselineCNN(nn.Module):
    """Lightweight CNN baseline with three feature stages and global pooling."""

    def __init__(self, num_classes: int) -> None:
        super().__init__()
        self.features = nn.Sequential(
            _conv_block(3, 32),
            _conv_block(32, 64),
            nn.MaxPool2d(2),
            _conv_block(64, 128),
            nn.MaxPool2d(2),
            _conv_block(128, 256),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        return self.classifier(x)


def _conv_block(in_channels: int, out_channels: int) -> nn.Sequential:
    """Creates a Conv-BN-ReLU block reused across the network."""

    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )


def parse_args() -> argparse.Namespace:
    """Parses command line arguments controlling training behavior."""

    parser = argparse.ArgumentParser(description="Train Food-101 image classifiers")
    parser.add_argument("--data-dir", type=Path, default=Path("data/raw/food-101"), help="Root directory of Food-101 dataset")
    parser.add_argument("--model", type=str, choices=["baseline", "resnet50", "efficientnet_b0"], default="baseline", help="Model architecture to train")
    parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, default=64, help="Mini-batch size")
    parser.add_argument("--learning-rate", type=float, default=1e-3, help="Optimizer learning rate")
    parser.add_argument("--val-split", type=float, default=0.1, help="Fraction of train set used for validation")
|
104 |
+
parser.add_argument("--num-workers", type=int, default=4, help="DataLoader worker processes")
|
105 |
+
parser.add_argument("--image-size", type=int, default=224, help="Square image size fed to the network")
|
106 |
+
parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
|
107 |
+
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Computation device")
|
108 |
+
parser.add_argument("--pretrained", action="store_true", default=None, help="Use pretrained weights when available (transfer models)")
|
109 |
+
parser.add_argument("--no-pretrained", action="store_false", dest="pretrained", help="Disable pretrained weights for transfer models")
|
110 |
+
parser.add_argument("--freeze-backbone", action="store_true", help="Freeze feature extractor when using transfer models")
|
111 |
+
parser.add_argument("--experiment-name", type=str, default=None, help="Optional experiment name for checkpoints/logs")
|
112 |
+
parser.add_argument("--checkpoint-dir", type=Path, default=Path("checkpoints"), help="Where to store model checkpoints")
|
113 |
+
parser.add_argument("--log-dir", type=Path, default=Path("logs"), help="Where to store training logs")
|
114 |
+
parser.add_argument("--early-stop-patience", type=int, default=5, help="Epochs to wait before stopping after no validation improvement")
|
115 |
+
parser.add_argument("--early-stop-min-delta", type=float, default=0.0, help="Minimum change to qualify as an improvement")
|
116 |
+
parser.add_argument("--early-stop-metric", choices=["accuracy", "loss"], default="accuracy", help="Validation metric used for early stopping")
|
117 |
+
parser.add_argument("--grad-clip-norm", type=float, default=None, help="Gradient clipping norm (L2). Disabled if not set")
|
118 |
+
parser.add_argument("--use-amp", action="store_true", help="Enable mixed precision training (requires CUDA)")
|
119 |
+
parser.add_argument("--wandb", dest="use_wandb", action="store_true", help="Log metrics to Weights & Biases")
|
120 |
+
parser.add_argument("--no-wandb", dest="use_wandb", action="store_false", help="Disable Weights & Biases logging")
|
121 |
+
parser.add_argument("--wandb-project", type=str, default=None, help="Weights & Biases project name")
|
122 |
+
parser.add_argument("--wandb-entity", type=str, default=None, help="Weights & Biases entity (team/user)")
|
123 |
+
parser.add_argument("--wandb-run-name", type=str, default=None, help="Weights & Biases run name override")
|
124 |
+
args = parser.parse_args()
|
125 |
+
|
126 |
+
# Default to pretrained weights for transfer models unless user explicitly disables them.
|
127 |
+
if args.pretrained is None:
|
128 |
+
args.pretrained = args.model != "baseline"
|
129 |
+
if not hasattr(args, "use_wandb"):
|
130 |
+
args.use_wandb = False
|
131 |
+
return args
|
132 |
+
|
133 |
+
|
134 |
+
def set_seed(seed: int) -> None:
|
135 |
+
"""Fixes random seeds for reproducibility across libraries."""
|
136 |
+
|
137 |
+
random.seed(seed)
|
138 |
+
np.random.seed(seed)
|
139 |
+
torch.manual_seed(seed)
|
140 |
+
if torch.cuda.is_available():
|
141 |
+
torch.cuda.manual_seed_all(seed)
|
142 |
+
torch.backends.cudnn.deterministic = True
|
143 |
+
torch.backends.cudnn.benchmark = False
|
144 |
+
|
145 |
+
|
146 |
+
def build_transforms(image_size: int) -> Tuple[A.BasicTransform, A.BasicTransform]:
|
147 |
+
"""Constructs augmentation pipelines for train and evaluation splits."""
|
148 |
+
|
149 |
+
normalize = A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
|
150 |
+
train_transform = A.Compose(
|
151 |
+
[
|
152 |
+
A.RandomResizedCrop(
|
153 |
+
size=(image_size, image_size),
|
154 |
+
scale=(0.8, 1.0),
|
155 |
+
ratio=(0.75, 1.33),
|
156 |
+
),
|
157 |
+
A.HorizontalFlip(p=0.5),
|
158 |
+
A.Affine(
|
159 |
+
scale=(0.9, 1.1),
|
160 |
+
translate_percent=(-0.05, 0.05),
|
161 |
+
rotate=(-15, 15),
|
162 |
+
shear=(0.0, 0.0),
|
163 |
+
p=0.3,
|
164 |
+
),
|
165 |
+
A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.3),
|
166 |
+
normalize,
|
167 |
+
ToTensorV2(),
|
168 |
+
]
|
169 |
+
)
|
170 |
+
eval_transform = A.Compose(
|
171 |
+
[
|
172 |
+
A.Resize(height=image_size + 32, width=image_size + 32),
|
173 |
+
A.CenterCrop(height=image_size, width=image_size),
|
174 |
+
normalize,
|
175 |
+
ToTensorV2(),
|
176 |
+
]
|
177 |
+
)
|
178 |
+
return train_transform, eval_transform
|
179 |
+
|
180 |
+
|
181 |
+
def build_model(
|
182 |
+
model_name: str,
|
183 |
+
num_classes: int,
|
184 |
+
pretrained: bool,
|
185 |
+
freeze_backbone: bool,
|
186 |
+
) -> nn.Module:
|
187 |
+
"""Factory that returns the requested architecture configured for Food-101."""
|
188 |
+
|
189 |
+
if model_name == "baseline":
|
190 |
+
model = BaselineCNN(num_classes=num_classes)
|
191 |
+
elif model_name == "resnet50":
|
192 |
+
weights = ResNet50_Weights.DEFAULT if pretrained else None
|
193 |
+
model = models.resnet50(weights=weights)
|
194 |
+
if freeze_backbone:
|
195 |
+
for param in model.parameters():
|
196 |
+
param.requires_grad = False
|
197 |
+
in_features = model.fc.in_features
|
198 |
+
model.fc = nn.Linear(in_features, num_classes)
|
199 |
+
elif model_name == "efficientnet_b0":
|
200 |
+
weights = EfficientNet_B0_Weights.DEFAULT if pretrained else None
|
201 |
+
model = models.efficientnet_b0(weights=weights)
|
202 |
+
if freeze_backbone:
|
203 |
+
for param in model.parameters():
|
204 |
+
param.requires_grad = False
|
205 |
+
in_features = model.classifier[-1].in_features
|
206 |
+
model.classifier[-1] = nn.Linear(in_features, num_classes)
|
207 |
+
else:
|
208 |
+
raise ValueError(f"Unsupported model: {model_name}")
|
209 |
+
|
210 |
+
return model
|
211 |
+
|
212 |
+
|
213 |
+
def load_food101_splits(data_dir: Path, val_split: float, seed: int) -> Tuple[List[Sample], List[Sample], List[Sample], Dict[int, str]]:
|
214 |
+
"""Loads Food-101 metadata and returns train/val/test splits."""
|
215 |
+
|
216 |
+
images_dir = data_dir / "images"
|
217 |
+
meta_dir = data_dir / "meta"
|
218 |
+
classes = _read_classes(meta_dir / "classes.txt")
|
219 |
+
class_to_idx = {name: idx for idx, name in enumerate(classes)}
|
220 |
+
|
221 |
+
with (meta_dir / "train.json").open() as f:
|
222 |
+
train_meta = json.load(f)
|
223 |
+
with (meta_dir / "test.json").open() as f:
|
224 |
+
test_meta = json.load(f)
|
225 |
+
|
226 |
+
rng = random.Random(seed)
|
227 |
+
|
228 |
+
train_samples: List[Sample] = []
|
229 |
+
val_samples: List[Sample] = []
|
230 |
+
for cls_name, items in train_meta.items():
|
231 |
+
paths = list(items)
|
232 |
+
rng.shuffle(paths)
|
233 |
+
val_count = max(1, int(len(paths) * val_split))
|
234 |
+
val_subset = paths[:val_count]
|
235 |
+
train_subset = paths[val_count:]
|
236 |
+
train_samples.extend(_build_samples(train_subset, images_dir, class_to_idx[cls_name]))
|
237 |
+
val_samples.extend(_build_samples(val_subset, images_dir, class_to_idx[cls_name]))
|
238 |
+
|
239 |
+
test_samples = []
|
240 |
+
for cls_name, items in test_meta.items():
|
241 |
+
test_samples.extend(_build_samples(items, images_dir, class_to_idx[cls_name]))
|
242 |
+
|
243 |
+
idx_to_class = {idx: name for name, idx in class_to_idx.items()}
|
244 |
+
return train_samples, val_samples, test_samples, idx_to_class
|
245 |
+
|
246 |
+
|
247 |
+
def _build_samples(items: Iterable[str], images_dir: Path, label: int) -> List[Sample]:
|
248 |
+
"""Creates Sample objects from relative paths and a target label."""
|
249 |
+
|
250 |
+
samples = []
|
251 |
+
for item in items:
|
252 |
+
path = images_dir / f"{item}.jpg"
|
253 |
+
samples.append(Sample(path=path, label=label))
|
254 |
+
return samples
|
255 |
+
|
256 |
+
|
257 |
+
def _read_classes(path: Path) -> List[str]:
|
258 |
+
"""Reads the list of Food-101 class names from disk."""
|
259 |
+
|
260 |
+
with path.open() as handle:
|
261 |
+
return [line.strip() for line in handle if line.strip()]
|
262 |
+
|
263 |
+
|
264 |
+
def create_dataloaders(
|
265 |
+
train_samples: Sequence[Sample],
|
266 |
+
val_samples: Sequence[Sample],
|
267 |
+
test_samples: Sequence[Sample],
|
268 |
+
train_transform: A.BasicTransform,
|
269 |
+
eval_transform: A.BasicTransform,
|
270 |
+
batch_size: int,
|
271 |
+
num_workers: int,
|
272 |
+
) -> Tuple[DataLoader, DataLoader, DataLoader]:
|
273 |
+
"""Wraps datasets with PyTorch DataLoaders."""
|
274 |
+
|
275 |
+
pin_memory = torch.cuda.is_available()
|
276 |
+
train_dataset = Food101Dataset(train_samples, transform=train_transform)
|
277 |
+
val_dataset = Food101Dataset(val_samples, transform=eval_transform)
|
278 |
+
test_dataset = Food101Dataset(test_samples, transform=eval_transform)
|
279 |
+
|
280 |
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
|
281 |
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
|
282 |
+
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
|
283 |
+
return train_loader, val_loader, test_loader
|
284 |
+
|
285 |
+
|
286 |
+
def write_metrics(log_path: Path, rows: Sequence[Dict[str, object]]) -> None:
|
287 |
+
"""Appends metric rows to the CSV log."""
|
288 |
+
|
289 |
+
log_path.parent.mkdir(parents=True, exist_ok=True)
|
290 |
+
file_exists = log_path.exists()
|
291 |
+
fieldnames = ["model", "experiment", "epoch", "split", "loss", "accuracy"]
|
292 |
+
with log_path.open("a", newline="") as csvfile:
|
293 |
+
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
294 |
+
if not file_exists:
|
295 |
+
writer.writeheader()
|
296 |
+
writer.writerows(rows)
|
297 |
+
|
298 |
+
|
299 |
+
def maybe_init_wandb(args: argparse.Namespace, config_extra: Dict[str, object]) -> Optional[object]:
|
300 |
+
"""Initializes Weights & Biases run if requested and available."""
|
301 |
+
|
302 |
+
if not args.use_wandb:
|
303 |
+
return None
|
304 |
+
if args.wandb_project is None:
|
305 |
+
print("[wandb] Project not specified; skipping W&B logging.")
|
306 |
+
return None
|
307 |
+
if wandb is None:
|
308 |
+
print("[wandb] Package not installed; skipping W&B logging.")
|
309 |
+
return None
|
310 |
+
|
311 |
+
run_name = args.wandb_run_name or args.experiment_name or args.model
|
312 |
+
config = {k: (str(v) if isinstance(v, Path) else v) for k, v in vars(args).items()}
|
313 |
+
config.update(config_extra)
|
314 |
+
run = wandb.init(project=args.wandb_project, entity=args.wandb_entity, name=run_name, config=config)
|
315 |
+
return run
|
316 |
+
|
317 |
+
|
318 |
+
def train_one_epoch(
|
319 |
+
model: nn.Module,
|
320 |
+
dataloader: DataLoader,
|
321 |
+
criterion: nn.Module,
|
322 |
+
optimizer: torch.optim.Optimizer,
|
323 |
+
device: torch.device,
|
324 |
+
scaler: GradScaler,
|
325 |
+
grad_clip_norm: Optional[float],
|
326 |
+
amp_enabled: bool,
|
327 |
+
clip_params: Sequence[torch.nn.Parameter],
|
328 |
+
amp_device_type: str,
|
329 |
+
) -> Tuple[float, float]:
|
330 |
+
"""Runs one training epoch and returns loss and accuracy."""
|
331 |
+
|
332 |
+
model.train()
|
333 |
+
running_loss = 0.0
|
334 |
+
correct = 0
|
335 |
+
total = 0
|
336 |
+
|
337 |
+
for inputs, targets in dataloader:
|
338 |
+
inputs = inputs.to(device)
|
339 |
+
targets = targets.to(device)
|
340 |
+
|
341 |
+
# Forward + backward pass and optimizer update.
|
342 |
+
optimizer.zero_grad()
|
343 |
+
with autocast(amp_device_type, enabled=amp_enabled):
|
344 |
+
outputs = model(inputs)
|
345 |
+
loss = criterion(outputs, targets)
|
346 |
+
|
347 |
+
if scaler.is_enabled():
|
348 |
+
scaler.scale(loss).backward()
|
349 |
+
if grad_clip_norm is not None:
|
350 |
+
scaler.unscale_(optimizer)
|
351 |
+
torch.nn.utils.clip_grad_norm_(clip_params, grad_clip_norm)
|
352 |
+
scaler.step(optimizer)
|
353 |
+
scaler.update()
|
354 |
+
else:
|
355 |
+
loss.backward()
|
356 |
+
if grad_clip_norm is not None:
|
357 |
+
torch.nn.utils.clip_grad_norm_(clip_params, grad_clip_norm)
|
358 |
+
optimizer.step()
|
359 |
+
|
360 |
+
running_loss += loss.item() * inputs.size(0)
|
361 |
+
preds = outputs.argmax(dim=1)
|
362 |
+
correct += (preds == targets).sum().item()
|
363 |
+
total += targets.size(0)
|
364 |
+
|
365 |
+
epoch_loss = running_loss / max(total, 1)
|
366 |
+
epoch_acc = correct / max(total, 1)
|
367 |
+
return epoch_loss, epoch_acc
|
368 |
+
|
369 |
+
|
370 |
+
def evaluate(
|
371 |
+
model: nn.Module,
|
372 |
+
dataloader: DataLoader,
|
373 |
+
criterion: nn.Module,
|
374 |
+
device: torch.device,
|
375 |
+
amp_enabled: bool,
|
376 |
+
amp_device_type: str,
|
377 |
+
) -> Tuple[float, float]:
|
378 |
+
"""Evaluates the model and returns loss and accuracy."""
|
379 |
+
|
380 |
+
model.eval()
|
381 |
+
running_loss = 0.0
|
382 |
+
correct = 0
|
383 |
+
total = 0
|
384 |
+
|
385 |
+
with torch.no_grad():
|
386 |
+
for inputs, targets in dataloader:
|
387 |
+
inputs = inputs.to(device)
|
388 |
+
targets = targets.to(device)
|
389 |
+
with autocast(amp_device_type, enabled=amp_enabled):
|
390 |
+
outputs = model(inputs)
|
391 |
+
loss = criterion(outputs, targets)
|
392 |
+
running_loss += loss.item() * inputs.size(0)
|
393 |
+
preds = outputs.argmax(dim=1)
|
394 |
+
correct += (preds == targets).sum().item()
|
395 |
+
total += targets.size(0)
|
396 |
+
|
397 |
+
epoch_loss = running_loss / max(total, 1)
|
398 |
+
epoch_acc = correct / max(total, 1)
|
399 |
+
return epoch_loss, epoch_acc
|
400 |
+
|
401 |
+
|
402 |
+
def main() -> None:
|
403 |
+
args = parse_args()
|
404 |
+
set_seed(args.seed)
|
405 |
+
|
406 |
+
data_dir = args.data_dir.expanduser().resolve()
|
407 |
+
if not data_dir.exists():
|
408 |
+
raise FileNotFoundError(f"Dataset directory not found: {data_dir}")
|
409 |
+
|
410 |
+
# Build samples for each split based on Food-101 metadata.
|
411 |
+
train_samples, val_samples, test_samples, idx_to_class = load_food101_splits(data_dir, args.val_split, args.seed)
|
412 |
+
num_classes = len(idx_to_class)
|
413 |
+
|
414 |
+
# Prepare augmentations and dataloaders.
|
415 |
+
train_transform, eval_transform = build_transforms(args.image_size)
|
416 |
+
train_loader, val_loader, test_loader = create_dataloaders(
|
417 |
+
train_samples,
|
418 |
+
val_samples,
|
419 |
+
test_samples,
|
420 |
+
train_transform,
|
421 |
+
eval_transform,
|
422 |
+
args.batch_size,
|
423 |
+
args.num_workers,
|
424 |
+
)
|
425 |
+
|
426 |
+
# Initialize model, loss, and optimizer on the desired device.
|
427 |
+
device = torch.device(args.device)
|
428 |
+
model = build_model(
|
429 |
+
model_name=args.model,
|
430 |
+
num_classes=num_classes,
|
431 |
+
pretrained=args.pretrained,
|
432 |
+
freeze_backbone=args.freeze_backbone,
|
433 |
+
).to(device)
|
434 |
+
criterion = nn.CrossEntropyLoss()
|
435 |
+
trainable_params = [p for p in model.parameters() if p.requires_grad]
|
436 |
+
if not trainable_params:
|
437 |
+
trainable_params = list(model.parameters())
|
438 |
+
optimizer = Adam(trainable_params, lr=args.learning_rate)
|
439 |
+
|
440 |
+
amp_enabled = args.use_amp and device.type == "cuda"
|
441 |
+
amp_device_type = "cuda" if device.type == "cuda" else "cpu"
|
442 |
+
scaler = GradScaler(amp_device_type, enabled=amp_enabled)
|
443 |
+
clip_params: Sequence[torch.nn.Parameter] = trainable_params
|
444 |
+
|
445 |
+
checkpoint_dir = args.checkpoint_dir
|
446 |
+
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
447 |
+
|
448 |
+
experiment_name = args.experiment_name or args.model
|
449 |
+
checkpoint_path = checkpoint_dir / f"{experiment_name}.pth"
|
450 |
+
log_path = args.log_dir / f"{experiment_name}.csv"
|
451 |
+
|
452 |
+
monitor_max = args.early_stop_metric == "accuracy"
|
453 |
+
best_metric = -float("inf") if monitor_max else float("inf")
|
454 |
+
patience_counter = 0
|
455 |
+
best_val_acc = 0.0
|
456 |
+
epochs_completed = 0
|
457 |
+
|
458 |
+
wandb_run = maybe_init_wandb(
|
459 |
+
args,
|
460 |
+
{
|
461 |
+
"num_classes": num_classes,
|
462 |
+
"train_samples": len(train_samples),
|
463 |
+
"val_samples": len(val_samples),
|
464 |
+
"test_samples": len(test_samples),
|
465 |
+
"amp_enabled": amp_enabled,
|
466 |
+
},
|
467 |
+
)
|
468 |
+
|
469 |
+
# Standard training loop with validation monitoring.
|
470 |
+
for epoch in range(1, args.epochs + 1):
|
471 |
+
train_loss, train_acc = train_one_epoch(
|
472 |
+
model,
|
473 |
+
train_loader,
|
474 |
+
criterion,
|
475 |
+
optimizer,
|
476 |
+
device,
|
477 |
+
scaler,
|
478 |
+
args.grad_clip_norm,
|
479 |
+
amp_enabled,
|
480 |
+
clip_params,
|
481 |
+
amp_device_type,
|
482 |
+
)
|
483 |
+
val_loss, val_acc = evaluate(model, val_loader, criterion, device, amp_enabled, amp_device_type)
|
484 |
+
print(
|
485 |
+
f"Epoch {epoch:02d}: train_loss={train_loss:.4f} train_acc={train_acc:.4f} "
|
486 |
+
f"val_loss={val_loss:.4f} val_acc={val_acc:.4f}"
|
487 |
+
)
|
488 |
+
write_metrics(
|
489 |
+
log_path,
|
490 |
+
[
|
491 |
+
{
|
492 |
+
"model": args.model,
|
493 |
+
"experiment": experiment_name,
|
494 |
+
"epoch": epoch,
|
495 |
+
"split": "train",
|
496 |
+
"loss": train_loss,
|
497 |
+
"accuracy": train_acc,
|
498 |
+
},
|
499 |
+
{
|
500 |
+
"model": args.model,
|
501 |
+
"experiment": experiment_name,
|
502 |
+
"epoch": epoch,
|
503 |
+
"split": "val",
|
504 |
+
"loss": val_loss,
|
505 |
+
"accuracy": val_acc,
|
506 |
+
},
|
507 |
+
],
|
508 |
+
)
|
509 |
+
if wandb_run is not None:
|
510 |
+
wandb_run.log(
|
511 |
+
{
|
512 |
+
"epoch": epoch,
|
513 |
+
"train/loss": train_loss,
|
514 |
+
"train/accuracy": train_acc,
|
515 |
+
"val/loss": val_loss,
|
516 |
+
"val/accuracy": val_acc,
|
517 |
+
},
|
518 |
+
step=epoch,
|
519 |
+
)
|
520 |
+
|
521 |
+
if val_acc > best_val_acc:
|
522 |
+
best_val_acc = val_acc
|
523 |
+
|
524 |
+
monitor_value = val_acc if monitor_max else val_loss
|
525 |
+
improved = (
|
526 |
+
monitor_value > best_metric + args.early_stop_min_delta
|
527 |
+
if monitor_max
|
528 |
+
else monitor_value < best_metric - args.early_stop_min_delta
|
529 |
+
)
|
530 |
+
if improved:
|
531 |
+
best_metric = monitor_value
|
532 |
+
patience_counter = 0
|
533 |
+
torch.save(model.state_dict(), checkpoint_path)
|
534 |
+
else:
|
535 |
+
patience_counter += 1
|
536 |
+
if patience_counter >= args.early_stop_patience:
|
537 |
+
print(
|
538 |
+
f"Early stopping triggered at epoch {epoch}. "
|
539 |
+
f"Best {args.early_stop_metric}: {best_metric:.4f}"
|
540 |
+
)
|
541 |
+
epochs_completed = epoch
|
542 |
+
break
|
543 |
+
|
544 |
+
epochs_completed = epoch
|
545 |
+
else:
|
546 |
+
epochs_completed = args.epochs
|
547 |
+
|
548 |
+
# Ensure we persist weights even if validation never improved.
|
549 |
+
if not checkpoint_path.exists():
|
550 |
+
torch.save(model.state_dict(), checkpoint_path)
|
551 |
+
|
552 |
+
# Load best checkpoint before final evaluation.
|
553 |
+
try:
|
554 |
+
state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
|
555 |
+
except TypeError:
|
556 |
+
state_dict = torch.load(checkpoint_path, map_location=device)
|
557 |
+
model.load_state_dict(state_dict)
|
558 |
+
|
559 |
+
# Final evaluation on the held-out test set.
|
560 |
+
test_loss, test_acc = evaluate(model, test_loader, criterion, device, amp_enabled, amp_device_type)
|
561 |
+
print(f"Test metrics - loss: {test_loss:.4f} accuracy: {test_acc:.4f}")
|
562 |
+
write_metrics(
|
563 |
+
log_path,
|
564 |
+
[
|
565 |
+
{
|
566 |
+
"model": args.model,
|
567 |
+
"experiment": experiment_name,
|
568 |
+
"epoch": epochs_completed,
|
569 |
+
"split": "test",
|
570 |
+
"loss": test_loss,
|
571 |
+
"accuracy": test_acc,
|
572 |
+
}
|
573 |
+
],
|
574 |
+
)
|
575 |
+
|
576 |
+
if wandb_run is not None:
|
577 |
+
wandb_run.log(
|
578 |
+
{
|
579 |
+
"test/loss": test_loss,
|
580 |
+
"test/accuracy": test_acc,
|
581 |
+
"best/val_accuracy": best_val_acc,
|
582 |
+
f"best/val_{args.early_stop_metric}": best_metric,
|
583 |
+
},
|
584 |
+
step=epochs_completed,
|
585 |
+
)
|
586 |
+
wandb_run.finish()
|
587 |
+
|
588 |
+
|
589 |
+
if __name__ == "__main__":
|
590 |
+
main()
|
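A typical training run, based on the flags defined above, would look like `python scripts/train.py --model efficientnet_b0 --epochs 10 --batch-size 64 --use-amp`. As a quick sanity check, the helpers in `scripts/train.py` can also be reused outside the CLI; this is only a sketch, assuming the repository root is on `PYTHONPATH`, and `pretrained=False` is used so no weights are downloaded:

```python
import torch

from scripts.train import build_model  # assumes the repo root is on PYTHONPATH

# Build the EfficientNet-B0 variant with a 101-class head, as the training script does.
model = build_model("efficientnet_b0", num_classes=101, pretrained=False, freeze_backbone=False)
model.eval()

with torch.no_grad():
    dummy = torch.randn(1, 3, 224, 224)  # matches the default --image-size of 224
    logits = model(dummy)

print(logits.shape)  # torch.Size([1, 101])
```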