enesmanan commited on
Commit
e1ab149
·
verified ·
1 Parent(s): 2ae907b

add deploy files

Browse files
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ from PIL import Image
7
+ import torchvision.transforms as transforms
8
+ import matplotlib.pyplot as plt
9
+
10
+ from models.model import EfficientNetModel, CNNModel
11
+
12
+ class AnimalClassifierApp:
13
+ def __init__(self):
14
+ """Initialize the application."""
15
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ self.labels = ["bird", "cat", "dog", "horse"]
17
+
18
+ # Image preprocessing
19
+ self.transform = transforms.Compose([
20
+ transforms.Resize((224, 224)),
21
+ transforms.ToTensor(),
22
+ transforms.Normalize(
23
+ mean=[0.485, 0.456, 0.406],
24
+ std=[0.229, 0.224, 0.225]
25
+ )
26
+ ])
27
+
28
+ # Load models
29
+ self.models = self.load_models()
30
+ if not self.models:
31
+ print("Warning: No models found in checkpoints directory!")
32
+
33
+ def load_models(self):
34
+ """Load both trained models."""
35
+ models = {}
36
+
37
+ # Load EfficientNet
38
+ try:
39
+ efficientnet = EfficientNetModel(num_classes=len(self.labels))
40
+ efficientnet_path = os.path.join("checkpoints", "efficientnet", "efficientnet_best_model.pth")
41
+ if os.path.exists(efficientnet_path):
42
+ checkpoint = torch.load(efficientnet_path, map_location=self.device, weights_only=True)
43
+ state_dict = checkpoint.get('model_state_dict', checkpoint)
44
+ efficientnet.load_state_dict(state_dict, strict=False)
45
+ efficientnet.eval()
46
+ models['EfficientNet'] = efficientnet
47
+ print("Successfully loaded EfficientNet model")
48
+ except Exception as e:
49
+ print(f"Error loading EfficientNet model: {str(e)}")
50
+
51
+ # Load CNN
52
+ try:
53
+ cnn = CNNModel(num_classes=len(self.labels))
54
+ cnn_path = os.path.join("checkpoints", "cnn", "cnn_best_model.pth")
55
+ if os.path.exists(cnn_path):
56
+ checkpoint = torch.load(cnn_path, map_location=self.device, weights_only=True)
57
+ state_dict = checkpoint.get('model_state_dict', checkpoint)
58
+ cnn.load_state_dict(state_dict, strict=False)
59
+ cnn.eval()
60
+ models['CNN'] = cnn
61
+ print("Successfully loaded CNN model")
62
+ except Exception as e:
63
+ print(f"Error loading CNN model: {str(e)}")
64
+
65
+ return models
66
+
67
+ def predict(self, image: Image.Image):
68
+ """Make predictions with both models and create comparison visualizations."""
69
+ if not self.models:
70
+ return "No trained models found. Please train the models first."
71
+
72
+ # Preprocess image
73
+ img_tensor = self.transform(image).unsqueeze(0).to(self.device)
74
+
75
+ # Get predictions from both models
76
+ results = {}
77
+ probabilities = {}
78
+ for model_name, model in self.models.items():
79
+ with torch.no_grad():
80
+ output = model(img_tensor)
81
+ probs = F.softmax(output, dim=1).squeeze().cpu().numpy()
82
+ probabilities[model_name] = probs
83
+
84
+ # Get top prediction
85
+ pred_idx = np.argmax(probs)
86
+ pred_label = self.labels[pred_idx]
87
+ pred_prob = probs[pred_idx]
88
+ results[model_name] = (pred_label, pred_prob)
89
+
90
+ # Create comparison plot
91
+ fig = plt.figure(figsize=(12, 5))
92
+
93
+ # Plot for EfficientNet
94
+ if 'EfficientNet' in probabilities:
95
+ plt.subplot(1, 2, 1)
96
+ plt.bar(self.labels, probabilities['EfficientNet'], color='skyblue')
97
+ plt.title('EfficientNet Predictions')
98
+ plt.ylim(0, 1)
99
+ plt.xticks(rotation=45)
100
+ plt.ylabel('Probability')
101
+
102
+ # Plot for CNN
103
+ if 'CNN' in probabilities:
104
+ plt.subplot(1, 2, 2)
105
+ plt.bar(self.labels, probabilities['CNN'], color='lightcoral')
106
+ plt.title('CNN Predictions')
107
+ plt.ylim(0, 1)
108
+ plt.xticks(rotation=45)
109
+ plt.ylabel('Probability')
110
+
111
+ plt.tight_layout()
112
+
113
+ # Create results text
114
+ text_results = "Model Predictions:\n\n"
115
+ for model_name, (label, prob) in results.items():
116
+ text_results += f"{model_name}:\n"
117
+ text_results += f"Top prediction: {label} ({prob:.2%})\n"
118
+ text_results += "All probabilities:\n"
119
+ for label, prob in zip(self.labels, probabilities[model_name]):
120
+ text_results += f" {label}: {prob:.2%}\n"
121
+ text_results += "\n"
122
+
123
+ return [
124
+ fig, # Probability plots
125
+ text_results # Detailed text results
126
+ ]
127
+
128
+ def create_interface(self):
129
+ """Create Gradio interface."""
130
+ return gr.Interface(
131
+ fn=self.predict,
132
+ inputs=gr.Image(type="pil"),
133
+ outputs=[
134
+ gr.Plot(label="Prediction Probabilities"),
135
+ gr.Textbox(label="Detailed Results", lines=10)
136
+ ],
137
+ title="Animal Classifier - Model Comparison",
138
+ description="Upload an image of an animal to see predictions from both EfficientNet and CNN models."
139
+ )
140
+
141
+ def main():
142
+ """Run the web application."""
143
+ app = AnimalClassifierApp()
144
+ interface = app.create_interface()
145
+ interface.launch(
146
+ server_name="0.0.0.0",
147
+ server_port=7860,
148
+ share=True
149
+ )
150
+
151
+ if __name__ == "__main__":
152
+ main()
checkpoints/cnn/cnn_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e20be6471690e71892f6b8bcc44f548cf8f876db51ca166953b1433c993e7bee
3
+ size 1557014
checkpoints/efficientnet/efficientnet_best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:754e2ec53c4f1f1c6d4a9398d38d0f415ac33a50c8794e2c6292137696ced2ee
3
+ size 48638100
model.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import timm
6
+
7
+
8
+ class BaseModel(nn.Module):
9
+ """Base model class for animal classification."""
10
+
11
+ def predict(self, x: torch.Tensor) -> torch.Tensor:
12
+ """Get probability predictions."""
13
+ with torch.no_grad():
14
+ logits = self(x)
15
+ return F.softmax(logits, dim=1)
16
+
17
+ @classmethod
18
+ def load_from_checkpoint(
19
+ cls,
20
+ path: str,
21
+ map_location: Any = None
22
+ ) -> 'BaseModel':
23
+ """Load model from checkpoint."""
24
+ checkpoint = torch.load(path, map_location=map_location)
25
+ model = cls(num_classes=checkpoint['config']['num_classes'])
26
+ model.load_state_dict(checkpoint['model_state_dict'])
27
+ return model
28
+
29
+ def save_checkpoint(
30
+ self,
31
+ path: str,
32
+ extra_data: Dict[str, Any] = None
33
+ ) -> None:
34
+ """Save model checkpoint."""
35
+ data = {
36
+ 'model_state_dict': self.state_dict(),
37
+ 'config': {
38
+ 'num_classes': self.get_num_classes(),
39
+ 'model_type': self.__class__.__name__
40
+ }
41
+ }
42
+
43
+ if extra_data:
44
+ if 'config' in extra_data:
45
+ data['config'].update(extra_data['config'])
46
+ del extra_data['config']
47
+ data.update(extra_data)
48
+
49
+ torch.save(data, path)
50
+
51
+ def get_num_classes(self) -> int:
52
+ """Get number of output classes."""
53
+ raise NotImplementedError
54
+
55
+
56
+ class CNNModel(BaseModel):
57
+ def __init__(self, num_classes: int, input_size: int = 224):
58
+ super(CNNModel, self).__init__()
59
+
60
+ self.conv_layers = nn.Sequential(
61
+ # First block: 32 filters
62
+ nn.Conv2d(3, 32, kernel_size=3, padding=1),
63
+ nn.BatchNorm2d(32),
64
+ nn.ReLU(),
65
+ nn.MaxPool2d(2),
66
+
67
+ # Second block: 64 filters
68
+ nn.Conv2d(32, 64, kernel_size=3, padding=1),
69
+ nn.BatchNorm2d(64),
70
+ nn.ReLU(),
71
+ nn.MaxPool2d(2),
72
+
73
+ # Third block: 128 filters
74
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
75
+ nn.BatchNorm2d(128),
76
+ nn.ReLU(),
77
+ nn.MaxPool2d(2),
78
+
79
+ # Global Average Pooling
80
+ nn.AdaptiveAvgPool2d(1)
81
+ )
82
+
83
+ self.classifier = nn.Sequential(
84
+ nn.Flatten(),
85
+ nn.Dropout(0.5),
86
+ nn.Linear(128, 256),
87
+ nn.ReLU(),
88
+ nn.Dropout(0.3),
89
+ nn.Linear(256, num_classes)
90
+ )
91
+
92
+ self._initialize_weights()
93
+
94
+ def _initialize_weights(self):
95
+ """Initialize model weights."""
96
+ for m in self.modules():
97
+ if isinstance(m, nn.Conv2d):
98
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
99
+ if m.bias is not None:
100
+ nn.init.constant_(m.bias, 0)
101
+ elif isinstance(m, nn.BatchNorm2d):
102
+ nn.init.constant_(m.weight, 1)
103
+ nn.init.constant_(m.bias, 0)
104
+ elif isinstance(m, nn.Linear):
105
+ nn.init.normal_(m.weight, 0, 0.01)
106
+ nn.init.constant_(m.bias, 0)
107
+
108
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
109
+ x = self.conv_layers(x)
110
+ return self.classifier(x)
111
+
112
+ def get_num_classes(self) -> int:
113
+ return self.classifier[-1].out_features
114
+
115
+ class EfficientNetModel(BaseModel):
116
+ """EfficientNet-based model for animal classification."""
117
+
118
+ def __init__(
119
+ self,
120
+ num_classes: int,
121
+ model_name: str = "efficientnet_b0",
122
+ pretrained: bool = True
123
+ ):
124
+ super(EfficientNetModel, self).__init__()
125
+
126
+ self.base_model = timm.create_model(
127
+ model_name,
128
+ pretrained=pretrained,
129
+ num_classes=0
130
+ )
131
+
132
+ with torch.no_grad():
133
+ dummy_input = torch.randn(1, 3, 224, 224)
134
+ features = self.base_model(dummy_input)
135
+ feature_dim = features.shape[1]
136
+
137
+ # Simpler classifier structure matching the saved model
138
+ self.classifier = nn.Sequential(
139
+ nn.Dropout(0.2),
140
+ nn.Linear(feature_dim, num_classes)
141
+ )
142
+
143
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
144
+ features = self.base_model(x)
145
+ return self.classifier(features)
146
+
147
+ def get_num_classes(self) -> int:
148
+ return self.classifier[-1].out_features
149
+
150
+ def get_model(model_type: str, num_classes: int, **kwargs) -> BaseModel:
151
+ """Factory function to get model by type."""
152
+ models = {
153
+ 'cnn': CNNModel,
154
+ 'efficientnet': EfficientNetModel
155
+ }
156
+
157
+ if model_type not in models:
158
+ raise ValueError(f"Model type {model_type} not supported. Available models: {list(models.keys())}")
159
+
160
+ return models[model_type](num_classes=num_classes, **kwargs)
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.5.1
2
+ torchvision==0.20.1
3
+ timm==1.0.12
4
+ pillow==10.4.0
5
+ numpy==1.26.4
6
+ opencv-python==4.10.0
7
+ tqdm==4.67.1
8
+ matplotlib==3.7.5
9
+ gradio==5.9.1
10
+ wandb==0.19.1
11
+ datasets==3.2.0
12
+ scikit-learn==1.4.2