AshProg committed on
Commit
bb672b3
·
verified ·
1 Parent(s): f824450

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +197 -153
app.py CHANGED
@@ -1,153 +1,197 @@
1
- """
2
- Gradio App for Bird Species Classification
3
- Deployed on Hugging Face Spaces
4
- """
5
-
6
- import gradio as gr
7
- import torch
8
- import torch.nn as nn
9
- from torchvision import transforms
10
- from torchvision.models import convnext_base
11
- from PIL import Image
12
- import json
13
-
14
- # Load class names
15
- with open('class_names.json', 'r') as f:
16
- class_names = json.load(f)
17
-
18
- # Device configuration
19
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20
-
21
- # Create model architecture (same as training)
22
- def create_model(num_classes=200):
23
- """Create ConvNeXt model with same architecture as training"""
24
- model = convnext_base(weights=None)
25
-
26
- # Same classifier architecture as training
27
- num_ftrs = model.classifier[2].in_features
28
- model.classifier = nn.Sequential(
29
- nn.Flatten(1),
30
- nn.LayerNorm((num_ftrs,)),
31
- nn.Dropout(0.6),
32
- nn.Linear(num_ftrs, 512),
33
- nn.GELU(),
34
- nn.Dropout(0.5),
35
- nn.Linear(512, num_classes)
36
- )
37
-
38
- return model
39
-
40
- # Load the trained model
41
- print("Loading model...")
42
- model = create_model(num_classes=200)
43
-
44
- # Load weights
45
- checkpoint = torch.load('models/final_model.pth', map_location=device)
46
- if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
47
- model.load_state_dict(checkpoint['model_state_dict'])
48
- if 'val_acc' in checkpoint:
49
- val_acc = checkpoint['val_acc']
50
- print(f"Model loaded! Validation accuracy: {val_acc:.2f}%")
51
- else:
52
- model.load_state_dict(checkpoint)
53
- print("Model loaded!")
54
-
55
- model = model.to(device)
56
- model.eval()
57
-
58
- # Image preprocessing (same as validation transforms)
59
- transform = transforms.Compose([
60
- transforms.Resize((224, 224)),
61
- transforms.ToTensor(),
62
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
63
- ])
64
-
65
- def predict(image):
66
- """
67
- Make prediction on uploaded image
68
-
69
- Args:
70
- image: PIL Image
71
-
72
- Returns:
73
- dict: Top 5 predictions with confidence scores
74
- """
75
- # Preprocess image
76
- img_tensor = transform(image).unsqueeze(0).to(device)
77
-
78
- # Make prediction
79
- with torch.no_grad():
80
- outputs = model(img_tensor)
81
- probabilities = torch.nn.functional.softmax(outputs, dim=1)
82
-
83
- # Get top 5 predictions
84
- top5_prob, top5_idx = torch.topk(probabilities, 5)
85
-
86
- # Format results
87
- results = {}
88
- for i in range(5):
89
- class_id = top5_idx[0][i].item()
90
- prob = top5_prob[0][i].item()
91
- species_name = class_names.get(str(class_id), f"Class {class_id}")
92
- results[species_name] = float(prob)
93
-
94
- return results
95
-
96
- # Create Gradio interface
97
- title = "🐦 Bird Species Classification"
98
- description = """
99
- Upload an image of a bird and the model will predict the species!
100
-
101
- **Model Details:**
102
- - Architecture: ConvNeXt-Base (87M parameters)
103
- - Dataset: CUB-200-2011 (200 bird species)
104
- - Test Accuracy: 83.64%
105
- - Average Per-Class Accuracy: 83.29%
106
-
107
- **Training Strategy:**
108
- - Transfer Learning with ImageNet pretrained weights
109
- - Two-phase training: Frozen backbone (40 epochs) → Fine-tuning (20 epochs)
110
- - Strong regularization: Dropout (0.6, 0.5), Label smoothing (0.2)
111
- - Data augmentation: Rotation, flip, color jitter, random erasing
112
-
113
- Upload a clear image of a bird to get started!
114
- """
115
-
116
- article = """
117
- ### About This Model
118
-
119
- This bird classifier was trained on the CUB-200-2011 dataset containing 200 North American bird species.
120
- The model uses ConvNeXt-Base architecture with modern training techniques to achieve high accuracy while
121
- preventing overfitting.
122
-
123
- **Key Features:**
124
- - ✅ 200 bird species classification
125
- - ✅ State-of-the-art ConvNeXt architecture
126
- - ✅ 83.64% test accuracy
127
- - ✅ Real-time inference
128
-
129
- **Best Results:** Upload high-quality images with the bird clearly visible and centered.
130
- """
131
-
132
- examples = [
133
- # You can add example images here if you have them
134
- # ["examples/bird1.jpg"],
135
- # ["examples/bird2.jpg"],
136
- ]
137
-
138
- # Create interface
139
- iface = gr.Interface(
140
- fn=predict,
141
- inputs=gr.Image(type="pil", label="Upload Bird Image"),
142
- outputs=gr.Label(num_top_classes=5, label="Top 5 Predictions"),
143
- title=title,
144
- description=description,
145
- article=article,
146
- examples=examples if examples else None,
147
- theme=gr.themes.Soft(),
148
- allow_flagging="never",
149
- )
150
-
151
- # Launch the app
152
- if __name__ == "__main__":
153
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio App for Bird Species Classification
3
+ Deployed on Hugging Face Spaces
4
+ """
5
+
6
import json

import gradio as gr
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from torchvision.models import convnext_base
13
+
14
# Load the index -> species-name table. predict() looks names up with
# class_names.get(str(class_id), ...), so the JSON is expected to be an object
# keyed by the stringified class index.
# The encoding is pinned to UTF-8 so species names decode the same way on
# every platform instead of depending on the locale's default encoding.
with open('class_names.json', 'r', encoding='utf-8') as f:
    class_names = json.load(f)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20
+
21
+ # Create model architecture (same as training)
22
def create_model(num_classes=200):
    """Build a ConvNeXt-Base network with the custom classification head.

    The backbone is created without pretrained weights (``weights=None``);
    the caller is expected to load a trained state dict afterwards.

    Args:
        num_classes: Width of the final linear layer (number of output
            logits).

    Returns:
        The ConvNeXt-Base module with its ``classifier`` replaced.
    """
    network = convnext_base(weights=None)

    # Reuse the input width of the stock head's last layer for the new head.
    in_features = network.classifier[2].in_features

    # The order (and therefore the Sequential indices) must stay exactly as
    # below so that the parameter names match the saved checkpoint.
    head_layers = [
        nn.Flatten(1),
        nn.LayerNorm((in_features,)),
        nn.Dropout(0.6),
        nn.Linear(in_features, 512),
        nn.GELU(),
        nn.Dropout(0.5),
        nn.Linear(512, num_classes),
    ]
    network.classifier = nn.Sequential(*head_layers)

    return network
39
+
40
# Load the trained model.
#
# This section previously repeated the import block, the class-name load, the
# device setup and create_model(), and built the network twice (the first
# instance was immediately discarded). Everything it needs is already defined
# above, so only the download-and-load steps remain here.
print("Loading model...")

# Fetch the checkpoint from the Hugging Face Model Hub; hf_hub_download
# returns the local filesystem path of the downloaded file.
print("Downloading model from Hugging Face Model Hub...")
model_path = hf_hub_download(
    repo_id="AshProg/bird-classifier-convnext",
    filename="final_model.pth",
)

# Build the architecture once, then fill it with the trained weights.
model = create_model(num_classes=200)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code if the hosted checkpoint is ever replaced by an untrusted one. Consider
# passing weights_only=True once the checkpoint contents are confirmed to be
# plain tensors and Python scalars.
checkpoint = torch.load(model_path, map_location=device)

# Two checkpoint layouts are accepted: a training checkpoint dict that wraps
# the weights under 'model_state_dict' (optionally with 'val_acc'), or a bare
# state dict.
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    model.load_state_dict(checkpoint['model_state_dict'])
    if 'val_acc' in checkpoint:
        val_acc = checkpoint['val_acc']
        print(f"Model loaded! Validation accuracy: {val_acc:.2f}%")
    else:
        # Previously this path printed nothing, leaving the log ambiguous.
        print("Model loaded!")
else:
    model.load_state_dict(checkpoint)
    print("Model loaded!")

model = model.to(device)
# Inference mode: disables dropout in the classification head.
model.eval()
101
+
102
# Inference-time preprocessing (same as the validation transforms): resize to
# a fixed 224x224 input, convert to a tensor, then normalise each channel.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
transform = transforms.Compose(_preprocess_steps)
108
+
109
def predict(image, top_k=5):
    """Classify an uploaded image and return the most likely species.

    Args:
        image: PIL image supplied by the Gradio image input, or ``None`` when
            the user submits without uploading anything.
        top_k: Maximum number of predictions to return. Clamped to the number
            of classes the model outputs.

    Returns:
        dict: Maps species name to softmax probability, ordered from most to
        least likely. Empty when ``image`` is ``None``.
    """
    # Submitting the form with no image passes None; return an empty result
    # instead of crashing inside the transform pipeline.
    if image is None:
        return {}

    # Preprocess and add the batch dimension the model expects.
    img_tensor = transform(image).unsqueeze(0).to(device)

    # Forward pass without gradient tracking.
    with torch.no_grad():
        outputs = model(img_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)

    # torch.topk raises if k exceeds the number of classes, so clamp it.
    k = min(top_k, probabilities.shape[1])
    top_prob, top_idx = torch.topk(probabilities, k)

    # Translate class indices to names; fall back to a generic label when an
    # index is missing from class_names.
    results = {}
    for class_id, prob in zip(top_idx[0].tolist(), top_prob[0].tolist()):
        species_name = class_names.get(str(class_id), f"Class {class_id}")
        results[species_name] = float(prob)

    return results
139
+
140
# Static text shown by the Gradio interface. The arrow and check-mark
# characters were previously mis-encoded ("β†’", "βœ…"); they are restored
# to the intended "→" and "✅" so the page renders correctly.
title = "🐦 Bird Species Classification"
description = """
Upload an image of a bird and the model will predict the species!

**Model Details:**
- Architecture: ConvNeXt-Base (87M parameters)
- Dataset: CUB-200-2011 (200 bird species)
- Test Accuracy: 83.64%
- Average Per-Class Accuracy: 83.29%

**Training Strategy:**
- Transfer Learning with ImageNet pretrained weights
- Two-phase training: Frozen backbone (40 epochs) → Fine-tuning (20 epochs)
- Strong regularization: Dropout (0.6, 0.5), Label smoothing (0.2)
- Data augmentation: Rotation, flip, color jitter, random erasing

Upload a clear image of a bird to get started!
"""

article = """
### About This Model

This bird classifier was trained on the CUB-200-2011 dataset containing 200 North American bird species.
The model uses ConvNeXt-Base architecture with modern training techniques to achieve high accuracy while
preventing overfitting.

**Key Features:**
- ✅ 200 bird species classification
- ✅ State-of-the-art ConvNeXt architecture
- ✅ 83.64% test accuracy
- ✅ Real-time inference

**Best Results:** Upload high-quality images with the bird clearly visible and centered.
"""

# Optional example inputs for the interface. Left empty on purpose; add
# entries such as ["examples/bird1.jpg"] once example images are available.
examples = [
    # You can add example images here if you have them
    # ["examples/bird1.jpg"],
    # ["examples/bird2.jpg"],
]
181
+
182
# Wire predict() into a single-image -> label Gradio interface.
image_input = gr.Image(type="pil", label="Upload Bird Image")
label_output = gr.Label(num_top_classes=5, label="Top 5 Predictions")

iface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title=title,
    description=description,
    article=article,
    # An empty example list is passed as None, as in the original expression.
    examples=examples or None,
    theme=gr.themes.Soft(),
    allow_flagging="never",
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    iface.launch()