"""Gradio app comparing an EfficientNet and a small CNN animal classifier.

Models are restored from ``checkpoints/`` at start-up. Each uploaded image is
classified by every model that loaded successfully, and the per-class
probabilities are shown as bar charts plus a text summary.
"""

import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from matplotlib.figure import Figure
from PIL import Image


class BaseModel(nn.Module):
    """Common interface shared by the classifiers in this module."""

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """Return softmax class probabilities for batch *x*, without gradients.

        This does not switch the module to eval mode; call ``.eval()`` first
        if dropout / batch-norm should behave deterministically.
        """
        with torch.no_grad():
            logits = self(x)
        return F.softmax(logits, dim=1)

    def get_num_classes(self) -> int:
        """Return the number of output classes (implemented by subclasses)."""
        raise NotImplementedError


class CNNModel(BaseModel):
    """Three conv blocks (32/64/128 filters) + global average pooling + MLP head.

    Args:
        num_classes: Number of output logits.
        input_size: Accepted for interface compatibility but not used; the
            adaptive average pool makes the head independent of input size.
    """

    def __init__(self, num_classes: int, input_size: int = 224):
        super().__init__()
        self.conv_layers = nn.Sequential(
            # First block: 32 filters
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Second block: 64 filters
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Third block: 128 filters
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Global Average Pooling -> (batch, 128, 1, 1)
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )
        self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Kaiming init for convs, unit/zero for batch-norm, small normal for linears."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits for image batch *x*."""
        x = self.conv_layers(x)
        return self.classifier(x)

    def get_num_classes(self) -> int:
        """Return the output size of the final linear layer."""
        return self.classifier[-1].out_features


class EfficientNetModel(BaseModel):
    """timm backbone (classifier removed) followed by dropout + linear head.

    Args:
        num_classes: Number of output logits.
        model_name: timm model identifier for the backbone.
        pretrained: Whether timm should load its pretrained backbone weights.
    """

    def __init__(
        self, num_classes: int, model_name: str = "efficientnet_b0", pretrained: bool = True
    ):
        super().__init__()
        # num_classes=0 makes timm return pooled features instead of logits.
        self.base_model = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0
        )
        feature_dim = self._infer_feature_dim()
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(feature_dim, num_classes),
        )

    def _infer_feature_dim(self) -> int:
        """Probe the backbone with a dummy 224x224 image to get its feature size."""
        was_training = self.base_model.training
        # BatchNorm updates its running statistics in training mode even under
        # no_grad, so probe in eval mode to keep the (pretrained) stats intact.
        self.base_model.eval()
        try:
            with torch.no_grad():
                features = self.base_model(torch.randn(1, 3, 224, 224))
        finally:
            self.base_model.train(was_training)
        return features.shape[1]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits for image batch *x*."""
        features = self.base_model(x)
        return self.classifier(features)

    def get_num_classes(self) -> int:
        """Return the output size of the final linear layer."""
        return self.classifier[-1].out_features


class AnimalClassifierApp:
    """Gradio front-end that runs every available classifier on an uploaded image."""

    # Bar colour per model name; unknown names fall back to grey.
    _PLOT_COLORS = {"EfficientNet": "skyblue", "CNN": "lightcoral"}

    def __init__(self):
        """Initialize the application."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Index order must match the class order used when the checkpoints were trained.
        self.labels = ["bird", "cat", "dog", "horse"]
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            # ImageNet channel statistics.
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])
        self.models = self.load_models()
        if not self.models:
            print("Warning: No models found in checkpoints directory!")

    def load_models(self):
        """Load every classifier whose checkpoint file exists.

        Returns:
            Dict mapping display name ("EfficientNet", "CNN") to a model placed
            on ``self.device`` in eval mode. Models whose checkpoint is missing
            or fails to load are omitted.
        """
        num_classes = len(self.labels)
        candidates = [
            (
                "EfficientNet",
                os.path.join("checkpoints", "efficientnet", "efficientnet_best_model.pth"),
                # The checkpoint is expected to carry the backbone weights, so
                # skip timm's pretrained download; a key-mismatch warning is
                # printed by _load_model if that expectation does not hold.
                lambda: EfficientNetModel(num_classes=num_classes, pretrained=False),
            ),
            (
                "CNN",
                os.path.join("checkpoints", "cnn", "cnn_best_model.pth"),
                lambda: CNNModel(num_classes=num_classes),
            ),
        ]
        models = {}
        for name, path, factory in candidates:
            model = self._load_model(name, path, factory)
            if model is not None:
                models[name] = model
        return models

    def _load_model(self, display_name, path, factory):
        """Build a model with *factory* and restore its weights from *path*.

        Returns None when the checkpoint does not exist or loading fails (the
        error is printed), so one broken model does not stop the app.
        """
        if not os.path.exists(path):
            return None
        try:
            model = factory()
            checkpoint = torch.load(path, map_location=self.device, weights_only=True)
            # Accept both {"model_state_dict": ...} wrappers and bare state dicts.
            state_dict = checkpoint.get("model_state_dict", checkpoint)
            result = model.load_state_dict(state_dict, strict=False)
            # strict=False would otherwise hide a mismatched checkpoint and
            # silently leave layers at their initial weights.
            if result.missing_keys or result.unexpected_keys:
                print(
                    f"Warning: {display_name} checkpoint does not match the model "
                    f"(missing keys: {len(result.missing_keys)}, "
                    f"unexpected keys: {len(result.unexpected_keys)})"
                )
            # Must live on the same device predict() sends the input tensor to.
            model.to(self.device)
            model.eval()
        except Exception as e:
            print(f"Error loading {display_name} model: {str(e)}")
            return None
        print(f"Successfully loaded {display_name} model")
        return model

    def predict(self, image: Image.Image):
        """Classify *image* with every loaded model.

        Returns:
            ``[figure, text]`` for the Plot and Textbox outputs. The figure is
            None when there is nothing to plot (no models or no image).
        """
        if not self.models:
            return [None, "No trained models found. Please train the models first."]
        if image is None:
            return [None, "Please upload an image."]

        # Defensive: force 3 channels so RGBA / grayscale / palette images match
        # the 3-channel Normalize and the models' RGB input layer.
        img_tensor = self.transform(image.convert("RGB")).unsqueeze(0).to(self.device)

        probabilities = {}
        for model_name, model in self.models.items():
            with torch.no_grad():
                logits = model(img_tensor)
            # [0] drops the batch dimension of the single-image batch.
            probabilities[model_name] = F.softmax(logits, dim=1)[0].cpu().numpy()

        return [self._plot_probabilities(probabilities), self._format_results(probabilities)]

    def _plot_probabilities(self, probabilities):
        """Draw one probability bar chart per model, side by side."""
        # A standalone Figure (not pyplot) is not registered in pyplot's global
        # figure manager, so per-request figures are not leaked and concurrent
        # requests do not share a "current axes".
        fig = Figure(figsize=(12, 5), tight_layout=True)
        ncols = len(probabilities)
        for position, (model_name, probs) in enumerate(probabilities.items(), start=1):
            ax = fig.add_subplot(1, ncols, position)
            ax.bar(self.labels, probs, color=self._PLOT_COLORS.get(model_name, "gray"))
            ax.set_title(f"{model_name} Predictions")
            ax.set_ylim(0, 1)
            ax.tick_params(axis="x", labelrotation=45)
            ax.set_ylabel("Probability")
        return fig

    def _format_results(self, probabilities):
        """Render the top prediction and every class probability per model."""
        lines = ["Model Predictions:", ""]
        for model_name, probs in probabilities.items():
            top = int(np.argmax(probs))
            lines.append(f"{model_name}:")
            lines.append(f"Top prediction: {self.labels[top]} ({probs[top]:.2%})")
            lines.append("All probabilities:")
            lines.extend(f" {label}: {prob:.2%}" for label, prob in zip(self.labels, probs))
            lines.append("")
        return "\n".join(lines) + "\n"

    def create_interface(self):
        """Create Gradio interface."""
        return gr.Interface(
            fn=self.predict,
            inputs=gr.Image(type="pil"),
            outputs=[
                gr.Plot(label="Prediction Probabilities"),
                gr.Textbox(label="Detailed Results", lines=10),
            ],
            title="Animal Classifier - Model Comparison",
            description=(
                "Upload an image of one of these animals: Bird, Cat, Dog, or Horse.\n"
                "The app will compare predictions from both EfficientNet and CNN models.\n\n"
                "Note: For best results, ensure the animal is clearly visible in the image."
            ),
        )


def main():
    """Build the app and start the Gradio server."""
    app = AnimalClassifierApp()
    interface = app.create_interface()
    interface.launch()


if __name__ == "__main__":
    main()