AshProg commited on
Commit
fd62084
Β·
verified Β·
1 Parent(s): 1209574

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -0
app.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio App for Bird Species Classification
3
+ Deployed on Hugging Face Spaces
4
+ """
5
+
6
+ import gradio as gr
7
+ import torch
8
+ import torch.nn as nn
9
+ from torchvision import transforms
10
+ from torchvision.models import convnext_base
11
+ from PIL import Image
12
+ import json
13
+
14
+ # Load class names
15
+ with open('class_names.json', 'r') as f:
16
+ class_names = json.load(f)
17
+
18
+ # Device configuration
19
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20
+
21
+ # Create model architecture (same as training)
22
+ def create_model(num_classes=200):
23
+ """Create ConvNeXt model with same architecture as training"""
24
+ model = convnext_base(weights=None)
25
+
26
+ # Same classifier architecture as training
27
+ num_ftrs = model.classifier[2].in_features
28
+ model.classifier = nn.Sequential(
29
+ nn.Flatten(1),
30
+ nn.LayerNorm((num_ftrs,)),
31
+ nn.Dropout(0.6),
32
+ nn.Linear(num_ftrs, 512),
33
+ nn.GELU(),
34
+ nn.Dropout(0.5),
35
+ nn.Linear(512, num_classes)
36
+ )
37
+
38
+ return model
39
+
40
+ # Load the trained model
41
+ print("Loading model...")
42
+ model = create_model(num_classes=200)
43
+
44
+ # Load weights
45
+ checkpoint = torch.load('models/final_model.pth', map_location=device)
46
+ if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
47
+ model.load_state_dict(checkpoint['model_state_dict'])
48
+ if 'val_acc' in checkpoint:
49
+ val_acc = checkpoint['val_acc']
50
+ print(f"Model loaded! Validation accuracy: {val_acc:.2f}%")
51
+ else:
52
+ model.load_state_dict(checkpoint)
53
+ print("Model loaded!")
54
+
55
+ model = model.to(device)
56
+ model.eval()
57
+
58
+ # Image preprocessing (same as validation transforms)
59
+ transform = transforms.Compose([
60
+ transforms.Resize((224, 224)),
61
+ transforms.ToTensor(),
62
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
63
+ ])
64
+
65
+ def predict(image):
66
+ """
67
+ Make prediction on uploaded image
68
+
69
+ Args:
70
+ image: PIL Image
71
+
72
+ Returns:
73
+ dict: Top 5 predictions with confidence scores
74
+ """
75
+ # Preprocess image
76
+ img_tensor = transform(image).unsqueeze(0).to(device)
77
+
78
+ # Make prediction
79
+ with torch.no_grad():
80
+ outputs = model(img_tensor)
81
+ probabilities = torch.nn.functional.softmax(outputs, dim=1)
82
+
83
+ # Get top 5 predictions
84
+ top5_prob, top5_idx = torch.topk(probabilities, 5)
85
+
86
+ # Format results
87
+ results = {}
88
+ for i in range(5):
89
+ class_id = top5_idx[0][i].item()
90
+ prob = top5_prob[0][i].item()
91
+ species_name = class_names.get(str(class_id), f"Class {class_id}")
92
+ results[species_name] = float(prob)
93
+
94
+ return results
95
+
96
+ # Create Gradio interface
97
+ title = "🐦 Bird Species Classification"
98
+ description = """
99
+ Upload an image of a bird and the model will predict the species!
100
+
101
+ **Model Details:**
102
+ - Architecture: ConvNeXt-Base (87M parameters)
103
+ - Dataset: CUB-200-2011 (200 bird species)
104
+ - Test Accuracy: 83.64%
105
+ - Average Per-Class Accuracy: 83.29%
106
+
107
+ **Training Strategy:**
108
+ - Transfer Learning with ImageNet pretrained weights
109
+ - Two-phase training: Frozen backbone (40 epochs) β†’ Fine-tuning (20 epochs)
110
+ - Strong regularization: Dropout (0.6, 0.5), Label smoothing (0.2)
111
+ - Data augmentation: Rotation, flip, color jitter, random erasing
112
+
113
+ Upload a clear image of a bird to get started!
114
+ """
115
+
116
+ article = """
117
+ ### About This Model
118
+
119
+ This bird classifier was trained on the CUB-200-2011 dataset containing 200 North American bird species.
120
+ The model uses ConvNeXt-Base architecture with modern training techniques to achieve high accuracy while
121
+ preventing overfitting.
122
+
123
+ **Key Features:**
124
+ - βœ… 200 bird species classification
125
+ - βœ… State-of-the-art ConvNeXt architecture
126
+ - βœ… 83.64% test accuracy
127
+ - βœ… Real-time inference
128
+
129
+ **Best Results:** Upload high-quality images with the bird clearly visible and centered.
130
+ """
131
+
132
+ examples = [
133
+ # You can add example images here if you have them
134
+ # ["examples/bird1.jpg"],
135
+ # ["examples/bird2.jpg"],
136
+ ]
137
+
138
+ # Create interface
139
+ iface = gr.Interface(
140
+ fn=predict,
141
+ inputs=gr.Image(type="pil", label="Upload Bird Image"),
142
+ outputs=gr.Label(num_top_classes=5, label="Top 5 Predictions"),
143
+ title=title,
144
+ description=description,
145
+ article=article,
146
+ examples=examples if examples else None,
147
+ theme=gr.themes.Soft(),
148
+ allow_flagging="never",
149
+ )
150
+
151
+ # Launch the app
152
+ if __name__ == "__main__":
153
+ iface.launch()