Spaces:
Sleeping
Sleeping
add deploy files
Browse files- app.py +152 -0
- checkpoints/cnn/cnn_best_model.pth +3 -0
- checkpoints/efficientnet/efficientnet_best_model.pth +3 -0
- model.py +160 -0
- requirements.txt +12 -0
app.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gradio as gr
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
import torchvision.transforms as transforms
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
|
10 |
+
from models.model import EfficientNetModel, CNNModel
|
11 |
+
|
12 |
+
class AnimalClassifierApp:
    """Gradio application that compares two animal classifiers on one image.

    On construction the app picks a torch device, builds the image
    preprocessing pipeline and tries to load an EfficientNet and a CNN
    checkpoint from the ``checkpoints`` directory. Models that fail to load
    are skipped; the app still starts with whatever is available.
    """

    def __init__(self):
        """Initialize the application."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.labels = ["bird", "cat", "dog", "horse"]

        # Image preprocessing: fixed 224x224 input, then per-channel
        # normalization with the mean/std constants below.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

        # Load models
        self.models = self.load_models()
        if not self.models:
            print("Warning: No models found in checkpoints directory!")

    def _load_model(self, display_name, model_cls, checkpoint_path):
        """Build one model and restore its weights from ``checkpoint_path``.

        Returns the model in eval mode on ``self.device``, or ``None`` when
        the checkpoint file does not exist. Any other failure propagates to
        the caller.
        """
        # Check the file first so no model is constructed when there is
        # nothing to load into it.
        if not os.path.exists(checkpoint_path):
            print(f"Checkpoint not found for {display_name} model: {checkpoint_path}")
            return None

        model = model_cls(num_classes=len(self.labels))
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=True)
        # Accept both a wrapped checkpoint dict and a bare state dict.
        state_dict = checkpoint.get('model_state_dict', checkpoint)

        # strict=False tolerates key mismatches; report them so a model that
        # silently kept un-restored weights is visible in the logs.
        load_result = model.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys or load_result.unexpected_keys:
            print(
                f"Warning: {display_name} checkpoint key mismatch "
                f"(missing: {list(load_result.missing_keys)}, "
                f"unexpected: {list(load_result.unexpected_keys)})"
            )

        # predict() moves the input tensor to self.device, so the model's
        # parameters must live on the same device.
        model.to(self.device)
        model.eval()
        return model

    def load_models(self):
        """Load both trained models.

        Returns:
            Dict mapping display name ("EfficientNet", "CNN") to the loaded
            model. A model whose checkpoint is missing or fails to load is
            left out.
        """
        specs = (
            (
                "EfficientNet",
                EfficientNetModel,
                os.path.join("checkpoints", "efficientnet", "efficientnet_best_model.pth"),
            ),
            (
                "CNN",
                CNNModel,
                os.path.join("checkpoints", "cnn", "cnn_best_model.pth"),
            ),
        )

        models = {}
        for display_name, model_cls, checkpoint_path in specs:
            # Best-effort: one broken checkpoint must not prevent the other
            # model from being served.
            try:
                model = self._load_model(display_name, model_cls, checkpoint_path)
            except Exception as e:
                print(f"Error loading {display_name} model: {str(e)}")
                continue
            if model is not None:
                models[display_name] = model
                print(f"Successfully loaded {display_name} model")

        return models

    def _build_figure(self, probabilities):
        """Draw one probability bar chart per available model."""
        # Build the Figure directly instead of via pyplot: pyplot keeps every
        # figure alive in a global registry until it is explicitly closed,
        # which leaks memory in a long-running server.
        from matplotlib.figure import Figure

        fig = Figure(figsize=(12, 5))
        panels = (
            ("EfficientNet", 1, "skyblue"),
            ("CNN", 2, "lightcoral"),
        )
        for model_name, position, color in panels:
            if model_name not in probabilities:
                continue
            ax = fig.add_subplot(1, 2, position)
            ax.bar(self.labels, probabilities[model_name], color=color)
            ax.set_title(f"{model_name} Predictions")
            ax.set_ylim(0, 1)
            ax.tick_params(axis="x", labelrotation=45)
            ax.set_ylabel("Probability")

        fig.tight_layout()
        return fig

    def _format_results(self, results, probabilities):
        """Render the top prediction and all class probabilities as text."""
        lines = ["Model Predictions:\n\n"]
        for model_name, (top_label, top_prob) in results.items():
            lines.append(f"{model_name}:\n")
            lines.append(f"Top prediction: {top_label} ({top_prob:.2%})\n")
            lines.append("All probabilities:\n")
            for label, prob in zip(self.labels, probabilities[model_name]):
                lines.append(f" {label}: {prob:.2%}\n")
            lines.append("\n")
        return "".join(lines)

    def predict(self, image: Image.Image):
        """Make predictions with both models and create comparison visualizations.

        Args:
            image: The uploaded image, or ``None`` when nothing was uploaded.

        Returns:
            A two-element list ``[figure, text]`` matching the two outputs
            declared in ``create_interface``. ``figure`` is ``None`` when no
            prediction could be made; ``text`` then carries the reason.
        """
        # Always return one value per declared output, even on the
        # early-exit paths.
        if not self.models:
            return [None, "No trained models found. Please train the models first."]
        if image is None:
            return [None, "No image provided. Please upload an image."]

        # Preprocess image. The normalization above is defined for three
        # channels, so force RGB before transforming.
        img_tensor = self.transform(image.convert("RGB")).unsqueeze(0).to(self.device)

        # Get predictions from every loaded model
        results = {}
        probabilities = {}
        with torch.no_grad():
            for model_name, model in self.models.items():
                output = model(img_tensor)
                probs = F.softmax(output, dim=1).squeeze(0).cpu().numpy()
                probabilities[model_name] = probs

                # Get top prediction
                pred_idx = int(np.argmax(probs))
                results[model_name] = (self.labels[pred_idx], probs[pred_idx])

        return [
            self._build_figure(probabilities),  # Probability plots
            self._format_results(results, probabilities)  # Detailed text results
        ]

    def create_interface(self):
        """Create Gradio interface."""
        return gr.Interface(
            fn=self.predict,
            inputs=gr.Image(type="pil"),
            outputs=[
                gr.Plot(label="Prediction Probabilities"),
                gr.Textbox(label="Detailed Results", lines=10)
            ],
            title="Animal Classifier - Model Comparison",
            description="Upload an image of an animal to see predictions from both EfficientNet and CNN models."
        )
|
140 |
+
|
141 |
+
def main():
    """Build the classifier app and serve its Gradio interface."""
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    AnimalClassifierApp().create_interface().launch(**launch_options)


if __name__ == "__main__":
    main()
|
checkpoints/cnn/cnn_best_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e20be6471690e71892f6b8bcc44f548cf8f876db51ca166953b1433c993e7bee
|
3 |
+
size 1557014
|
checkpoints/efficientnet/efficientnet_best_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:754e2ec53c4f1f1c6d4a9398d38d0f415ac33a50c8794e2c6292137696ced2ee
|
3 |
+
size 48638100
|
model.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Any
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import timm
|
6 |
+
|
7 |
+
|
8 |
+
class BaseModel(nn.Module):
    """Base model class for animal classification.

    Subclasses provide ``forward`` and ``get_num_classes``; this class adds
    probability prediction and checkpoint save/load on top of them.
    """

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """Get probability predictions.

        Runs the forward pass without gradient tracking and applies softmax
        over dimension 1. The module's train/eval mode is left as it is.
        """
        with torch.no_grad():
            logits = self(x)
            return F.softmax(logits, dim=1)

    @classmethod
    def load_from_checkpoint(
        cls,
        path: str,
        map_location: Any = None,
        weights_only: bool = True
    ) -> 'BaseModel':
        """Load model from checkpoint.

        Args:
            path: Checkpoint file written by ``save_checkpoint``.
            map_location: Forwarded to ``torch.load``.
            weights_only: Forwarded to ``torch.load``. Defaults to ``True``
                so loading does not execute arbitrary pickled code. Pass
                ``False`` only for trusted checkpoints whose extra data
                cannot be loaded in restricted mode.

        Returns:
            A new instance of ``cls`` with the stored weights loaded.

        Raises:
            KeyError: If the checkpoint has no ``config``/``num_classes`` or
                no ``model_state_dict`` entry.
        """
        checkpoint = torch.load(
            path, map_location=map_location, weights_only=weights_only
        )
        model = cls(num_classes=checkpoint['config']['num_classes'])
        model.load_state_dict(checkpoint['model_state_dict'])
        return model

    def save_checkpoint(
        self,
        path: str,
        extra_data: Dict[str, Any] = None
    ) -> None:
        """Save model checkpoint.

        Args:
            path: Destination file.
            extra_data: Optional extra entries to store. A nested ``config``
                dict is merged into the stored config; all other keys are
                stored at the top level. The passed dict is not modified.
        """
        data = {
            'model_state_dict': self.state_dict(),
            'config': {
                'num_classes': self.get_num_classes(),
                'model_type': self.__class__.__name__
            }
        }

        if extra_data:
            # Work on a shallow copy so the caller's dict keeps its 'config'
            # entry and can be reused for later saves.
            extras = dict(extra_data)
            config_overrides = extras.pop('config', None)
            if config_overrides:
                data['config'].update(config_overrides)
            data.update(extras)

        torch.save(data, path)

    def get_num_classes(self) -> int:
        """Get number of output classes."""
        raise NotImplementedError
|
54 |
+
|
55 |
+
|
56 |
+
class CNNModel(BaseModel):
    """Small convolutional classifier.

    Three conv / batch-norm / ReLU / max-pool stages (32, 64 and 128
    filters) followed by global average pooling and a two-layer classifier
    head with dropout.
    """

    def __init__(self, num_classes: int, input_size: int = 224):
        super().__init__()

        # Assemble the feature extractor stage by stage. The flat layer
        # order fixes the Sequential indices and therefore the state-dict
        # keys.
        feature_layers = []
        in_channels = 3
        for out_channels in (32, 64, 128):
            feature_layers += [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]
            in_channels = out_channels
        # Global Average Pooling
        feature_layers.append(nn.AdaptiveAvgPool2d(1))
        self.conv_layers = nn.Sequential(*feature_layers)

        head_layers = [
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(in_channels, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialize conv, batch-norm and linear parameters."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the classifier output for a batch of images."""
        pooled = self.conv_layers(x)
        return self.classifier(pooled)

    def get_num_classes(self) -> int:
        """Return the output width of the final linear layer."""
        final_layer = self.classifier[-1]
        return final_layer.out_features
|
114 |
+
|
115 |
+
class EfficientNetModel(BaseModel):
    """EfficientNet-based model for animal classification.

    Wraps a ``timm`` backbone created without its own classification head
    (``num_classes=0``) and adds a dropout + linear classifier on top of the
    backbone's feature output.
    """

    def __init__(
        self,
        num_classes: int,
        model_name: str = "efficientnet_b0",
        pretrained: bool = True
    ):
        """Create the backbone and the classifier head.

        Args:
            num_classes: Number of output classes.
            model_name: Architecture name passed to ``timm.create_model``.
            pretrained: Passed to ``timm.create_model``.
        """
        super().__init__()

        self.base_model = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0
        )

        # Probe the backbone once to find its feature width. Run the probe
        # in eval mode: a training-mode forward pass would update the
        # backbone's BatchNorm running statistics with random noise.
        was_training = self.base_model.training
        self.base_model.eval()
        with torch.no_grad():
            dummy_input = torch.randn(1, 3, 224, 224)
            features = self.base_model(dummy_input)
            feature_dim = features.shape[1]
        # Restore whatever mode the backbone was created in.
        self.base_model.train(was_training)

        # Simpler classifier structure matching the saved model
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(feature_dim, num_classes)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the classifier output for a batch of images."""
        features = self.base_model(x)
        return self.classifier(features)

    def get_num_classes(self) -> int:
        """Return the output width of the final linear layer."""
        return self.classifier[-1].out_features
|
149 |
+
|
150 |
+
def get_model(model_type: str, num_classes: int, **kwargs) -> BaseModel:
    """Build the model registered under ``model_type``.

    Extra keyword arguments are forwarded to the model constructor.

    Raises:
        ValueError: If ``model_type`` is not a registered model name.
    """
    registry = {
        'cnn': CNNModel,
        'efficientnet': EfficientNetModel
    }

    model_cls = registry.get(model_type)
    if model_cls is None:
        raise ValueError(f"Model type {model_type} not supported. Available models: {list(registry)}")

    return model_cls(num_classes=num_classes, **kwargs)
|
requirements.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.5.1
|
2 |
+
torchvision==0.20.1
|
3 |
+
timm==1.0.12
|
4 |
+
pillow==10.4.0
|
5 |
+
numpy==1.26.4
|
6 |
+
opencv-python==4.10.0.84
|
7 |
+
tqdm==4.67.1
|
8 |
+
matplotlib==3.7.5
|
9 |
+
gradio==5.9.1
|
10 |
+
wandb==0.19.1
|
11 |
+
datasets==3.2.0
|
12 |
+
scikit-learn==1.4.2
|